Merge branch 'master' into knowledge_base_1b_L2_ng_orth2feat_cain

This commit is contained in:
Daniel Povey 2022-05-19 12:42:50 +08:00
commit e44edf99a4
69 changed files with 4181 additions and 355 deletions

View File

@ -0,0 +1,15 @@
#!/usr/bin/env bash
# This script downloads the pre-computed fbank features for
# dev and test datasets of GigaSpeech.
#
# You will find directories `~/tmp/giga-dev-dataset-fbank` after running
# this script.
mkdir -p ~/tmp
cd ~/tmp
git lfs install
git clone https://huggingface.co/csukuangfj/giga-dev-dataset-fbank
ls -lh giga-dev-dataset-fbank/data/fbank

View File

@ -0,0 +1,49 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/gigaspeech/ASR
repo_url=https://huggingface.co/wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained-iter-3488000-avg-20.pt pruned_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh data/lang_bpe_500
ls -lh data/fbank
ls -lh pruned_transducer_stateless2/exp
log "Decoding dev and test"
# use a small value for decoding with CPU
max_duration=100
# Test only greedy_search to reduce CI running time
# for method in greedy_search fast_beam_search modified_beam_search; do
for method in greedy_search; do
log "Decoding with $method"
./pruned_transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless2/exp
done
rm pruned_transducer_stateless2/exp/*.pt
fi

View File

@ -0,0 +1,80 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
popd
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless3/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./pruned_transducer_stateless3/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless3/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless3/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless3/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless3/exp
done
rm pruned_transducer_stateless3/exp/*.pt
fi

View File

@ -0,0 +1,120 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-gigaspeech-2022-05-13
# stateless transducer + k2 pruned rnnt-loss + reworked conformer
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_gigaspeech_2022_05_13:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Download GigaSpeech dev/test dataset
shell: bash
run: |
sudo apt-get install -y -q git-lfs
.github/scripts/download-gigaspeech-dev-test-dataset.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
ln -s ~/tmp/giga-dev-dataset-fbank/data egs/gigaspeech/ASR/
ls -lh egs/gigaspeech/ASR/data/fbank
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh
- name: Display decoding results for gigaspeech pruned_transducer_stateless2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/gigaspeech/ASR/
tree ./pruned_transducer_stateless2/exp
sudo apt-get -qq install tree
cd pruned_transducer_stateless2
echo "results for pruned_transducer_stateless2"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
- name: Upload decoding results for gigaspeech pruned_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/

View File

@ -142,8 +142,8 @@ jobs:
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search===" echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Display decoding results for pruned_transducer_stateless3 - name: Display decoding results for pruned_transducer_stateless3
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
@ -161,8 +161,8 @@ jobs:
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search===" echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless2 - name: Upload decoding results for pruned_transducer_stateless2
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v2

View File

@ -0,0 +1,151 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-pruned-transducer-stateless3-2022-05-13
# stateless pruned transducer (reworked model) + giga speech
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
- name: Display decoding results for pruned_transducer_stateless3
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR
tree pruned_transducer_stateless3/exp
cd pruned_transducer_stateless3/exp
echo "===greedy search==="
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless3
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/

View File

@ -103,11 +103,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc cd egs/librispeech/ASR/conformer_ctc
pytest -v -s pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer cd ../transducer
pytest -v -s pytest -v -s
cd ../transducer_stateless cd ../transducer_stateless2
pytest -v -s pytest -v -s
cd ../transducer_lstm cd ../transducer_lstm
@ -128,11 +143,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc cd egs/librispeech/ASR/conformer_ctc
pytest -v -s pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer cd ../transducer
pytest -v -s pytest -v -s
cd ../transducer_stateless cd ../transducer_stateless2
pytest -v -s pytest -v -s
cd ../transducer_lstm cd ../transducer_lstm

View File

@ -107,7 +107,7 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 2.19 | 4.97 | | WER | 2.00 | 4.63 |
### Aishell ### Aishell
@ -200,19 +200,22 @@ We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless mod
### GigaSpeech ### GigaSpeech
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
#### Conformer CTC #### Conformer CTC
| | Dev | Test | | | Dev | Test |
|-----|-------|-------| |-----|-------|-------|
| WER | 10.47 | 10.58 | | WER | 10.47 | 10.58 |
#### Pruned stateless RNN-T #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
| | Dev | Test | | | Dev | Test |
|----------------------|-------|-------| |----------------------|-------|-------|
| greedy search | 10.59 | 10.87 | | greedy search | 10.51 | 10.73 |
| fast beam search | 10.56 | 10.80 | | fast beam search | 10.50 | 10.69 |
| modified beam search | 10.52 | 10.62 | | modified beam search | 10.40 | 10.51 |
## Deployment with C++ ## Deployment with C++
@ -238,6 +241,8 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc [TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc
[TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless [TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless
[TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless [TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless
[GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc
[GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2
[yesno]: egs/yesno/ASR [yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR [librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR [aishell]: egs/aishell/ASR

View File

@ -16,6 +16,6 @@ ln -sfv /path/to/GigaSpeech download/GigaSpeech
| | Dev | Test | | | Dev | Test |
|--------------------------------|-------|-------| |--------------------------------|-------|-------|
| `conformer_ctc` | 10.47 | 10.58 | | `conformer_ctc` | 10.47 | 10.58 |
| `pruned_transducer_stateless2` | 10.52 | 10.62 | | `pruned_transducer_stateless2` | 10.40 | 10.51 |
See [RESULTS](/egs/gigaspeech/ASR/RESULTS.md) for details. See [RESULTS](/egs/gigaspeech/ASR/RESULTS.md) for details.

View File

@ -11,13 +11,15 @@ decoder contains only an embedding layer, a Conv1d (with kernel
size 2) and a linear layer (to transform tensor dim). k2 pruned size 2) and a linear layer (to transform tensor dim). k2 pruned
RNN-T loss is used. RNN-T loss is used.
The best WER, as of 2022-05-12, for the gigaspeech is below
Results are: Results are:
| | Dev | Test | | | Dev | Test |
|----------------------|-------|-------| |----------------------|-------|-------|
| greedy search | 10.59 | 10.87 | | greedy search | 10.51 | 10.73 |
| fast beam search | 10.56 | 10.80 | | fast beam search | 10.50 | 10.69 |
| modified beam search | 10.52 | 10.62 | | modified beam search | 10.40 | 10.51 |
To reproduce the above result, use the following commands for training: To reproduce the above result, use the following commands for training:
@ -39,33 +41,30 @@ and the following commands for decoding:
```bash ```bash
# greedy search # greedy search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 29 \ --iter 3488000 \
--avg 11 \ --avg 20 \
--decoding-method greedy_search \ --decoding-method greedy_search \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--max-duration 20 \ --max-duration 600
--num-workers 1
# fast beam search # fast beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 29 \ --iter 3488000 \
--avg 9 \ --avg 20 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--max-duration 20 \ --max-duration 600
--num-workers 1
# modified beam search # modified beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 29 \ --iter 3488000 \
--avg 8 \ --avg 15 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--max-duration 20 \ --max-duration 600
--num-workers 1
``` ```
Pretrained model is available at Pretrained model is available at

View File

@ -22,7 +22,7 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search
@ -30,7 +30,7 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
@ -39,7 +39,7 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
@ -48,7 +48,7 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
@ -99,27 +99,28 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=29, default=29,
help="It specifies the checkpoint to use for decoding." help="""It specifies the checkpoint to use for decoding.
"Note: Epoch counts from 0.", Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
) )
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=8, default=8,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch' and '--iter'",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
) )
parser.add_argument( parser.add_argument(
@ -152,7 +153,7 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""An interger indicating how many candidates we will keep for each help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or frame. Used only when --decoding-method is beam_search or
modified_beam_search.""", modified_beam_search.""",
) )
@ -465,7 +466,11 @@ def main():
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
@ -488,8 +493,9 @@ def main():
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py # <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)
@ -497,8 +503,20 @@ def main():
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
if params.avg_last_n > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -51,7 +51,11 @@ import sentencepiece as spm
import torch import torch
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool from icefall.utils import str2bool
@ -64,8 +68,19 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="It specifies the checkpoint to use for decoding." help="""It specifies the checkpoint to use for averaging.
"Note: Epoch counts from 0.", Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
) )
parser.add_argument( parser.add_argument(
@ -74,7 +89,7 @@ def get_parser():
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch' and '--iter'",
) )
parser.add_argument( parser.add_argument(
@ -141,7 +156,24 @@ def main():
model.to(device) model.to(device)
if params.avg == 1: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else: else:
start = params.epoch - params.avg + 1 start = params.epoch - params.avg + 1

View File

@ -689,7 +689,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -831,10 +831,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
gigaspeech = GigaSpeechAsrDataModule(args) gigaspeech = GigaSpeechAsrDataModule(args)

View File

@ -1,6 +1,6 @@
## Results ## Results
### LibriSpeech BPE training results (Pruned Transducer 3) ### LibriSpeech BPE training results (Pruned Transducer 3, 2022-04-29)
[pruned_transducer_stateless3](./pruned_transducer_stateless3) [pruned_transducer_stateless3](./pruned_transducer_stateless3)
Same as `Pruned Transducer 2` but using the XL subset from Same as `Pruned Transducer 2` but using the XL subset from
@ -152,6 +152,67 @@ for epoch in 27; do
done done
``` ```
### LibriSpeech BPE training results (Pruned Transducer 3, 2022-05-13)
Same setup as [pruned_transducer_stateless3](./pruned_transducer_stateless3) (2022-04-29)
but change `--giga-prob` from 0.8 to 0.9. Also use `repeat` on gigaspeech XL
subset so that the gigaspeech dataloader never exhausts.
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|---------------------------------------------|
| greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 |
| modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 |
| fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 |
The training commands are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./prepare.sh
./prepare_giga_speech.sh
./pruned_transducer_stateless3/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 0 \
--full-libri 1 \
--exp-dir pruned_transducer_stateless3/exp-0.9 \
--max-duration 300 \
--use-fp16 1 \
--lr-epochs 4 \
--num-workers 2 \
--giga-prob 0.9
```
The tensorboard log is available at
<https://tensorboard.dev/experiment/HpocR7dKS9KCQkJeYxfXug/>
Decoding commands:
```bash
for iter in 1224000; do
for avg in 14; do
for method in greedy_search modified_beam_search fast_beam_search ; do
./pruned_transducer_stateless3/decode.py \
--iter $iter \
--avg $avg \
--exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
--max-duration 600 \
--decoding-method $method \
--max-sym-per-frame 1 \
--beam 4 \
--max-contexts 32
done
done
done
```
The pretrained models, training logs, decoding logs, and decoding results
can be found at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>
### LibriSpeech BPE training results (Pruned Transducer 2) ### LibriSpeech BPE training results (Pruned Transducer 2)
[pruned_transducer_stateless2](./pruned_transducer_stateless2) [pruned_transducer_stateless2](./pruned_transducer_stateless2)

View File

@ -116,8 +116,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -159,6 +157,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -19,38 +19,38 @@ Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(2) beam search (2) beam search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`. You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
@ -233,6 +233,9 @@ def main():
logging.info("Creating model") logging.info("Creating model")
model = get_transducer_model(params) model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu") checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False) model.load_state_dict(checkpoint["model"], strict=False)
model.to(device) model.to(device)

View File

@ -29,6 +29,7 @@ from decoder import Decoder
def test_decoder(): def test_decoder():
vocab_size = 3 vocab_size = 3
blank_id = 0 blank_id = 0
unk_id = 2
embedding_dim = 128 embedding_dim = 128
context_size = 4 context_size = 4
@ -36,6 +37,7 @@ def test_decoder():
vocab_size=vocab_size, vocab_size=vocab_size,
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
blank_id=blank_id, blank_id=blank_id,
unk_id=unk_id,
context_size=context_size, context_size=context_size,
) )
N = 100 N = 100

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -112,10 +112,13 @@ class Conformer(EncoderInterface):
x, pos_emb = self.encoder_pos(x) x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings(): # Caution: We assume the subsampling factor is 4!
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4! # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
lengths = ((x_lens - 1) // 2 - 1) // 2 #
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths) mask = make_pad_mask(lengths)

View File

@ -131,8 +131,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -191,6 +189,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -19,20 +19,38 @@ Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless2/pretrained.py \ ./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(1) beam search (2) beam search
./pruned_transducer_stateless2/pretrained.py \ ./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
@ -79,9 +97,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--bpe-model", "--bpe-model",
type=str, type=str,
help="""Path to bpe.model. help="""Path to bpe.model.""",
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -117,7 +133,33 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="Used only when --method is beam_search and modified_beam_search", help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -244,9 +286,9 @@ def main():
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=8.0, beam=params.beam,
max_contexts=32, max_contexts=params.max_contexts,
max_states=8, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -254,6 +296,7 @@ def main():
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
@ -263,6 +306,7 @@ def main():
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())

View File

@ -212,7 +212,10 @@ class ScaledLinear(nn.Linear):
return self.weight * self.weight_scale.exp() return self.weight * self.weight_scale.exp()
def get_bias(self): def get_bias(self):
return None if self.bias is None else self.bias * self.bias_scale.exp() if self.bias is None or self.bias_scale is None:
return None
return self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear( return torch.nn.functional.linear(
@ -255,7 +258,11 @@ class ScaledConv1d(nn.Conv1d):
return self.weight * self.weight_scale.exp() return self.weight * self.weight_scale.exp()
def get_bias(self): def get_bias(self):
return None if self.bias is None else self.bias * self.bias_scale.exp() bias = self.bias
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional F = torch.nn.functional
@ -269,7 +276,7 @@ class ScaledConv1d(nn.Conv1d):
self.get_weight(), self.get_weight(),
self.get_bias(), self.get_bias(),
self.stride, self.stride,
_single(0), (0,),
self.dilation, self.dilation,
self.groups, self.groups,
) )
@ -319,7 +326,12 @@ class ScaledConv2d(nn.Conv2d):
return self.weight * self.weight_scale.exp() return self.weight * self.weight_scale.exp()
def get_bias(self): def get_bias(self):
return None if self.bias is None else self.bias * self.bias_scale.exp() # see https://github.com/pytorch/pytorch/issues/24135
bias = self.bias
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
def _conv_forward(self, input, weight): def _conv_forward(self, input, weight):
F = torch.nn.functional F = torch.nn.functional
@ -333,7 +345,7 @@ class ScaledConv2d(nn.Conv2d):
weight, weight,
self.get_bias(), self.get_bias(),
self.stride, self.stride,
_pair(0), (0, 0),
self.dilation, self.dilation,
self.groups, self.groups,
) )
@ -398,6 +410,9 @@ class ActivationBalancer(torch.nn.Module):
self.max_abs = max_abs self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
return x
return ActivationBalancerFunction.apply( return ActivationBalancerFunction.apply(
x, x,
self.channel_dim, self.channel_dim,
@ -444,6 +459,8 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)), """Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1). that we approximate closely with x * sigmoid(x-1).
""" """
if torch.jit.is_scripting():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x) return DoubleSwishFunction.apply(x)

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless2/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -695,7 +695,7 @@ def train_one_epoch(
display_and_save_batch(batch, params=params, sp=sp) display_and_save_batch(batch, params=params, sp=sp)
raise raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -839,10 +839,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)

View File

@ -132,8 +132,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -192,6 +190,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -19,20 +19,38 @@ Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless3/pretrained.py \ ./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(1) beam search (2) beam search
./pruned_transducer_stateless3/pretrained.py \ ./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless3/exp/epoch-xx.pt`. You can also use `./pruned_transducer_stateless3/exp/epoch-xx.pt`.
@ -79,9 +97,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--bpe-model", "--bpe-model",
type=str, type=str,
help="""Path to bpe.model. help="""Path to bpe.model.""",
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -117,7 +133,33 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="Used only when --method is beam_search and modified_beam_search", help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -244,9 +286,9 @@ def main():
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=8.0, beam=params.beam,
max_contexts=32, max_contexts=params.max_contexts,
max_states=8, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -254,6 +296,7 @@ def main():
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
@ -263,6 +306,7 @@ def main():
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless3/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless3/test_scaling.py
"""
import torch
from scaling import ActivationBalancer, ScaledConv1d, ScaledConv2d
def test_scaled_conv1d():
for bias in [True, False]:
conv1d = ScaledConv1d(
3,
6,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
torch.jit.script(conv1d)
def test_scaled_conv2d():
for bias in [True, False]:
conv2d = ScaledConv2d(
in_channels=1,
out_channels=3,
kernel_size=3,
padding=1,
bias=bias,
)
torch.jit.script(conv2d)
def test_activation_balancer():
act = ActivationBalancer(
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
torch.jit.script(act)
def main():
test_scaled_conv1d()
test_scaled_conv2d()
test_activation_balancer()
if __name__ == "__main__":
main()

View File

@ -767,7 +767,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -938,10 +938,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir) librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
@ -968,6 +965,7 @@ def run(rank, world_size, args):
train_giga_cuts = gigaspeech.train_S_cuts() train_giga_cuts = gigaspeech.train_S_cuts()
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
train_giga_cuts = train_giga_cuts.repeat(times=None)
if args.enable_musan: if args.enable_musan:
cuts_musan = load_manifest( cuts_musan = load_manifest(

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -724,7 +724,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if ( if (
@ -888,10 +888,7 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"]) scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)

View File

@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
) )
if False: if False:
# It is commented out as DPP complains that not all parameters are # It is commented out as DDP complains that not all parameters are
# used. Need more checks later for the reason. # used. Need more checks later for the reason.
# #
# Caution: We assume the dataloader returns utterances with # Caution: We assume the dataloader returns utterances with
@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
) )
packed_rnn_out, _ = self.rnn(packed_x) packed_rnn_out, _ = self.rnn(packed_x)
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True) rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
else: else:
rnn_out, _ = self.rnn(x) rnn_out, _ = self.rnn(x)

View File

@ -97,8 +97,7 @@ class Transducer(nn.Module):
y_lens = row_splits[1:] - row_splits[:-1] y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id sos_y = add_sos(y, sos_id=blank_id)
sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64) sos_y_padded = sos_y_padded.to(torch.int64)

View File

@ -109,10 +109,12 @@ class Conformer(Transformer):
x, pos_emb = self.encoder_pos(x) x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings(): # Caution: We assume the subsampling factor is 4!
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4! # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
lengths = ((x_lens - 1) // 2 - 1) // 2 #
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths) mask = make_pad_mask(lengths)

View File

@ -183,8 +183,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -226,6 +224,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -55,8 +57,8 @@ class Joiner(nn.Module):
N = encoder_out.size(0) N = encoder_out.size(0)
encoder_out_len = encoder_out_len.tolist() encoder_out_len: List[int] = encoder_out_len.tolist()
decoder_out_len = decoder_out_len.tolist() decoder_out_len: List[int] = decoder_out_len.tolist()
encoder_out_list = [ encoder_out_list = [
encoder_out[i, : encoder_out_len[i], :] for i in range(N) encoder_out[i, : encoder_out_len[i], :] for i in range(N)

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_stateless/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -523,7 +523,7 @@ def train_one_epoch(
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
@ -635,10 +635,7 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:

View File

@ -115,8 +115,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -158,6 +156,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -30,7 +32,8 @@ class Joiner(nn.Module):
self, self,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
decoder_out: torch.Tensor, decoder_out: torch.Tensor,
*unused, unused_encoder_out_len: Optional[torch.Tensor] = None,
unused_decoder_out_len: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -38,10 +41,12 @@ class Joiner(nn.Module):
Output from the encoder. Its shape is (N, T, self.input_dim). Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out: decoder_out:
Output from the decoder. Its shape is (N, U, self.input_dim). Output from the decoder. Its shape is (N, U, self.input_dim).
unused: unused_encoder_out_len:
This is a placeholder so that we can reuse This is a placeholder so that we can reuse
transducer_stateless/beam_search.py in this folder as that transducer_stateless/beam_search.py in this folder as that
script assumes the joiner networks accepts 4 inputs. script assumes the joiner networks accepts 4 inputs.
unused_decoder_out_len:
Just a placeholder.
Returns: Returns:
Return a tensor of shape (N, T, U, self.output_dim). Return a tensor of shape (N, T, U, self.output_dim).
""" """

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_stateless2/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -511,7 +511,7 @@ def train_one_epoch(
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 30:
return return
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
@ -623,10 +623,7 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( diagnostic = diagnostics.attach_diagnostics(model)
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:

View File

@ -184,8 +184,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -229,6 +227,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -0,0 +1,32 @@
# SPGISpeech
SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective
transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in
length to allow easy training for speech recognition systems. Calls represent a broad
cross-section of international business English; SPGISpeech contains approximately 50,000
speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and
L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio.
Transcription text represents the output of several stages of manual post-processing.
As such, the text contains polished English orthography following a detailed style guide,
including proper casing, punctuation, and denormalized non-standard words such as numbers
and acronyms, making SPGISpeech suited for training fully formatted end-to-end models.
Official reference:
ONeill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam,
J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G.
(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted
end-to-end speech recognition. ArXiv, abs/2104.02014.
ArXiv link: https://arxiv.org/abs/2104.02014
## Performance Record
| Decoding method | val WER | val CER |
|---------------------------|------------|---------|
| greedy search | 2.40 | 0.99 |
| modified beam search | 2.24 | 0.91 |
| fast beam search | 2.35 | 0.97 |
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.

View File

@ -0,0 +1,73 @@
## Results
### SPGISpeech BPE training results (Pruned Transducer)
#### 2022-05-11
#### Conformer encoder + embedding decoder
Conformer encoder + non-current decoder. The decoder
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform tensor dim).
The WERs are
| | dev | val | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the
transcripts are orthographic or normalized. These WERs correspond to the normalized transcription
scenario.
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 20 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--max-duration 200 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25 \
--use-fp16 True
```
The decoding command is:
```
# greedy search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
# modified beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
# fast beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2>
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/ExSoBmrPRx6liMTGLu0Tgw/#scalars>

View File

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_hlg.py

View File

@ -0,0 +1,104 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path
import torch
from lhotse import LilcomChunkyWriter, CutSet, combine
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
dataset_parts = (
"music",
"speech",
"noise",
)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir
)
assert manifests is not None
musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_musan",
manifest_path=src_dir / f"cuts_musan.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

View File

@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the SPGISpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import torch
from lhotse import load_manifest_lazy, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions,
)
from lhotse.manipulation import combine
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-splits",
type=int,
default=20,
help="Number of splits for the train set.",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Start index of the train set split.",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop index of the train set split.",
)
parser.add_argument(
"--test",
action="store_true",
help="If set, only compute features for the dev and val set.",
)
parser.add_argument(
"--train",
action="store_true",
help="If set, only compute features for the train set.",
)
return parser.parse_args()
def compute_fbank_spgispeech(args):
assert args.train or args.test, "Either train or test must be set."
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
if args.train:
logging.info(f"Processing train")
cut_set = load_manifest_lazy(src_dir / f"cuts_train_raw.jsonl.gz")
chunk_size = len(cut_set) // args.num_splits
cut_sets = cut_set.split_lazy(
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
chunk_size=chunk_size,
)
start = args.start
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
num_digits = len(str(args.num_splits))
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
logging.info(f"Processing train split {i}")
cs = cut_sets[i].compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_train_{idx}",
manifest_path=src_dir / f"cuts_train_{idx}.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
if args.test:
for partition in ["dev", "val"]:
if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{partition}",
manifest_path=src_dir / f"cuts_{partition}.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_spgispeech(args)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file splits the training set into train and dev sets.
"""
import logging
from pathlib import Path
import torch
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def split_spgispeech_train():
src_dir = Path("data/manifests")
manifests = read_manifests_if_cached(
dataset_parts=["train", "val"],
output_dir=src_dir,
prefix="spgispeech",
suffix="jsonl.gz",
lazy=True,
)
assert manifests is not None
train_dev_cuts = CutSet.from_manifests(
recordings=manifests["train"]["recordings"],
supervisions=manifests["train"]["supervisions"],
)
dev_cuts = train_dev_cuts.subset(first=4000)
train_cuts = train_dev_cuts.filter(lambda c: c not in dev_cuts)
# Add speed perturbation
train_cuts = (
train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
)
# Write the manifests to disk.
train_cuts.to_file(src_dir / "cuts_train_raw.jsonl.gz")
dev_cuts.to_file(src_dir / "cuts_dev_raw.jsonl.gz")
# Also write the val set to disk.
val_cuts = CutSet.from_manifests(
recordings=manifests["val"]["recordings"],
supervisions=manifests["val"]["supervisions"],
)
val_cuts.to_file(src_dir / "cuts_val_raw.jsonl.gz")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
split_spgispeech_train()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

196
egs/spgispeech/ASR/prepare.sh Executable file
View File

@ -0,0 +1,196 @@
#!/usr/bin/env bash
set -eou pipefail
nj=20
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/spgispeech
# You can find train.csv, val.csv, train, and val in this directory, which belong
# to the SPGISpeech dataset.
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
500
)
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/spgispeech,
# you can create a symlink
#
# ln -sfv /path/to/spgispeech $dl_dir/spgispeech
#
if [ ! -d $dl_dir/spgispeech/train.csv ]; then
lhotse download spgispeech $dl_dir
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare SPGISpeech manifest (may take ~1h)"
# We assume that you have downloaded the SPGISpeech corpus
# to $dl_dir/spgispeech. We perform text normalization for the transcripts.
mkdir -p data/manifests
lhotse prepare spgispeech -j $nj --normalize-text $dl_dir/spgispeech data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
lhotse combine data/manifests/recordings_{music,speech,noise}.json data/manifests/recordings_musan.jsonl.gz
lhotse cut simple -r data/manifests/recordings_musan.jsonl.gz data/manifests/cuts_musan_raw.jsonl.gz
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split train into train and dev and create cut sets."
python local/prepare_splits.py
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank features for spgispeech dev and val"
mkdir -p data/fbank
python local/compute_fbank_spgispeech.py --test
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank features for train"
mkdir -p data/fbank
python local/compute_fbank_spgispeech.py --train --num-splits 20
log "Combine features from train splits (may take ~1h)"
if [ ! -f data/manifests/cuts_train.jsonl.gz ]; then
pieces=$(find data/manifests -name "cuts_train_[0-9]*.jsonl.gz")
lhotse combine $pieces data/manifests/cuts_train.jsonl.gz
fi
gunzip -c data/manifests/train_cuts.jsonl.gz | shuf | gzip -c > data/manifests/train_cuts_shuf.jsonl.gz
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Compute fbank features for musan"
mkdir -p data/fbank
python local/compute_fbank_musan.py
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Dump transcripts for LM training"
mkdir -p data/lm
gunzip -c data/manifests/cuts_train_raw.jsonl.gz \
| jq '.supervisions[0].text' \
| sed 's:"::g' \
> data/lm/transcript_words.txt
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# Add special words to words.txt
echo "<eps> 0" > $lang_dir/words.txt
echo "!SIL 1" >> $lang_dir/words.txt
echo "[UNK] 2" >> $lang_dir/words.txt
# Add regular words to words.txt
gunzip -c data/manifests/cuts_train_raw.jsonl.gz \
| jq '.supervisions[0].text' \
| sed 's:"::g' \
| sed 's: :\n:g' \
| sort \
| uniq \
| sed '/^$/d' \
| awk '{print $0,NR+2}' \
>> $lang_dir/words.txt
# Add remaining special word symbols expected by LM scripts.
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "<s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "</s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "#0 ${num_words}" >> $lang_dir/words.txt
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript data/lm/transcript_words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
fi
done
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Train LM"
lm_dir=data/lm
if [ ! -f $lm_dir/G.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lm_dir/transcript_words.txt \
-lm $lm_dir/G.arpa
fi
if [ ! -f $lm_dir/G_3_gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$lm_dir/G.arpa > $lm_dir/G_3_gram.fst.txt
fi
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
done
fi

View File

@ -0,0 +1,366 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from tqdm import tqdm
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class SPGISpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it "
"with training dataset. ",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--max-duration",
type=int,
default=100.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
if self.args.on_the_fly_feats:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
)
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
drop_last=True,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get SPGISpeech train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get SPGISpeech dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
@lru_cache()
def val_cuts(self) -> CutSet:
logging.info("About to get SPGISpeech val cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_val.jsonl.gz")
def test():
parser = argparse.ArgumentParser()
SPGISpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
adm = SPGISpeechAsrDataModule(args)
cuts = adm.train_cuts()
dl = adm.train_dataloaders(cuts)
for i, batch in tqdm(enumerate(dl)):
if i == 100:
break
cuts = adm.dev_cuts()
dl = adm.valid_dataloaders(cuts)
for i, batch in tqdm(enumerate(dl)):
if i == 100:
break
if __name__ == "__main__":
test()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py

View File

@ -0,0 +1,594 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/decode.py \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless2/decode.py \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import SPGISpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=20,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
test_set_cers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
wers_filename = (
params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(wers_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
# we also compute CER for spgispeech dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
cers_filename = (
params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(cers_filename, "w") as f:
cer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_cers[key] = cer
logging.info("Wrote detailed error stats to {}".format(wers_filename))
test_set_wers = {
k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
}
test_set_cers = {
k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
}
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER\tCER", file=f)
for key in test_set_wers:
print(
"{}\t{}\t{}".format(
key, test_set_wers[key], test_set_cers[key]
),
file=f,
)
s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key in test_set_wers:
s += "{}\t{}\t{}{}\n".format(
key, test_set_wers[key], test_set_cers[key], note
)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
SPGISpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
spgispeech = SPGISpeechAsrDataModule(args)
dev_cuts = spgispeech.dev_cuts()
val_cuts = spgispeech.val_cuts()
dev_dl = spgispeech.test_dataloaders(dev_cuts)
val_dl = spgispeech.test_dataloaders(val_cuts)
test_sets = ["dev", "val"]
test_dl = [dev_dl, val_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,201 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--avg-last-n 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless2/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/spgispeech/ASR
./pruned_transducer_stateless2/decode.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py

File diff suppressed because it is too large Load Diff

1
egs/spgispeech/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -19,7 +19,7 @@
import random import random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -28,16 +28,12 @@ class TensorDiagnosticOptions(object):
"""Options object for tensor diagnostics: """Options object for tensor diagnostics:
Args: Args:
memory_limit:
The maximum number of bytes per tensor
(limits how many copies of the tensor we cache).
max_eig_dim: max_eig_dim:
The maximum dimension for which we print out eigenvalues The maximum dimension for which we print out eigenvalues
(limited for speed reasons). (limited for speed reasons).
""" """
def __init__(self, memory_limit: int = (2 ** 20), max_eig_dim: int = 512): def __init__(self, max_eig_dim: int = 512):
self.memory_limit = memory_limit
self.max_eig_dim = max_eig_dim self.max_eig_dim = max_eig_dim
def dim_is_summarized(self, size: int): def dim_is_summarized(self, size: int):
@ -94,138 +90,12 @@ def get_tensor_stats(
return x, count return x, count
def get_diagnostics_for_dim(
dim: int,
tensors: List[Tensor],
options: TensorDiagnosticOptions,
sizes_same: bool,
stats_type: str,
) -> str:
"""
This function gets diagnostics for a dimension of a module.
Args:
dim:
the dimension to analyze, with 0 <= dim < tensors[0].ndim
options:
options object
sizes_same:
True if all the tensor sizes are the same on this dimension
stats_type: either "abs" or "positive" or "eigs" or "value",
imdictates the type of stats we accumulate, abs is mean absolute
value, "positive" is proportion of positive to nonnegative values,
"eigs" is eigenvalues after doing outer product on this dim, sum
over all other dimes.
Returns:
Diagnostic as a string, either percentiles or the actual values,
see the code. Will return the empty string if the diagnostics did
not make sense to print out for this dimension, e.g. dimension
mismatch and stats_type == "eigs".
"""
# stats_and_counts is a list of pair (Tensor, int)
stats_and_counts = [get_tensor_stats(x, dim, stats_type) for x in tensors]
stats = [x[0] for x in stats_and_counts]
counts = [x[1] for x in stats_and_counts]
if stats_type == "eigs":
try:
stats = torch.stack(stats).sum(dim=0)
except: # noqa
return ""
count = sum(counts)
stats = stats / count
try:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same:
stats = torch.stack(stats).sum(dim=0)
count = sum(counts)
stats = stats / count
else:
stats = [x[0] / x[1] for x in stats_and_counts]
stats = torch.cat(stats, dim=0)
if stats_type == "rms":
stats = stats.sqrt()
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (not sizes_same) or options.dim_is_summarized(stats.numel())
if summarize:
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = ["%.2g" % x for x in percentiles]
percentiles = " ".join(percentiles)
ans = f"percentiles: [{percentiles}]"
else:
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}"
else:
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
return ans
def print_diagnostics_for_dim( @dataclass
name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions class TensorAndCount:
): tensor: Tensor
"""This function prints diagnostics for a dimension of a tensor. count: int
Args:
name:
The tensor name.
dim:
The dimension to analyze, with 0 <= dim < tensors[0].ndim.
tensors:
List of cached tensors to get the stats.
options:
Options object.
"""
ndim = tensors[0].ndim
if ndim > 1:
stats_types = ["abs", "positive", "value", "rms"]
if tensors[0].shape[dim] <= options.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
for stats_type in stats_types:
sizes = [x.shape[dim] for x in tensors]
sizes_same = all([x == sizes[0] for x in sizes])
s = get_diagnostics_for_dim(
dim, tensors, options, sizes_same, stats_type
)
if s == "":
continue
min_size = min(sizes)
max_size = max(sizes)
size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}"
# stats_type will be "abs" or "positive".
print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}")
class TensorDiagnostic(object): class TensorDiagnostic(object):
@ -238,12 +108,23 @@ class TensorDiagnostic(object):
name: name:
The tensor name. The tensor name.
""" """
def __init__(self, opts: TensorDiagnosticOptions, name: str): def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name self.name = name
self.opts = opts self.opts = opts
# A list to cache the tensors.
self.saved_tensors = []
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
# the keys into self.stats[dim] are strings, whose values can be
# "abs", "value", "positive", "rms", "value".
# The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
# containing a tensor and its associated count (which is the sum of the other dims
# that we aggregated over, e.g. the number of frames and/or batch elements and/or
# channels.
# ... we actually accumulate the Tensors / counts any time we have the same-dim tensor,
# only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
def accumulate(self, x): def accumulate(self, x):
"""Accumulate tensors.""" """Accumulate tensors."""
@ -251,50 +132,115 @@ class TensorDiagnostic(object):
x = x[0] x = x[0]
if not isinstance(x, Tensor): if not isinstance(x, Tensor):
return return
if x.device == torch.device("cpu"): x = x.detach().clone()
x = x.detach().clone() if x.ndim == 0:
else: x = x.unsqueeze(0)
x = x.detach().to("cpu", non_blocking=True) ndim = x.ndim
self.saved_tensors.append(x) if self.stats is None:
num = len(self.saved_tensors) self.stats = [ dict() for _ in range(ndim) ]
if num & (num - 1) == 0: # power of 2..
self._limit_memory()
def _limit_memory(self): for dim in range(ndim):
"""Only keep the newly cached tensors to limit memory.""" this_dim_stats = self.stats[dim]
if len(self.saved_tensors) > 1024: if ndim > 1:
self.saved_tensors = self.saved_tensors[-1024:] stats_types = ["abs", "positive", "value", "rms"]
return if x.shape[dim] <= self.opts.max_eig_dim:
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
this_dict = self.stats[dim]
for stats_type in stats_types:
stats, count = get_tensor_stats(x, dim, stats_type)
if not stats_type in this_dim_stats:
this_dim_stats[stats_type] = [] # list of TensorAndCount
done = False
if this_dim_stats[stats_type] is None:
# we can reach here if we detected for stats_type "eigs" that
# where was more than one different size for this dim. Then we
# disable accumulating this stats type, as it uses too much memory.
continue
for s in this_dim_stats[stats_type]:
if s.tensor.shape == stats.shape:
s.tensor += stats
s.count += count
done = True
break
if not done:
if this_dim_stats[stats_type] != [] and stats_type == "eigs":
# >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulat "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None
else:
this_dim_stats[stats_type].append(TensorAndCount(stats, count))
tot_mem = 0.0
for i in reversed(range(len(self.saved_tensors))):
tot_mem += (
self.saved_tensors[i].numel()
* self.saved_tensors[i].element_size()
)
if tot_mem > self.opts.memory_limit:
self.saved_tensors = self.saved_tensors[i:]
return
def print_diagnostics(self): def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor.""" """Print diagnostics for each dimension of the tensor."""
if len(self.saved_tensors) == 0: for dim, this_dim_stats in enumerate(self.stats):
print("{name}: no stats".format(name=self.name)) for stats_type, stats_list in this_dim_stats.items():
return # stats_type could be "rms", "value", "abs", "eigs", "positive".
# "value" could be a list of TensorAndCount, or None
if stats_list is None:
assert stats_type == "eigs"
continue
if self.saved_tensors[0].ndim == 0: if stats_type == "eigs":
# Ensure there is at least one dim. assert len(stats_list) == 1
self.saved_tensors = [x.unsqueeze(0) for x in self.saved_tensors] stats = stats_list[0].tensor / stats_list[0].count
try:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count
else:
stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
try: if stats_type == "rms":
device = torch.device("cuda") # we stored the square; after aggregation we need to take sqrt.
except: # noqa stats = stats.sqrt()
device = torch.device("cpu")
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
if summarize: # usually `summarize` will be true
# print out percentiles.
stats = stats.sort()[0]
num_percentiles = 10
size = stats.numel()
percentiles = []
for i in range(num_percentiles + 1):
index = (i * (size - 1)) // num_percentiles
percentiles.append(stats[index].item())
percentiles = ["%.2g" % x for x in percentiles]
percentiles = " ".join(percentiles)
ans = f"percentiles: [{percentiles}]"
else:
ans = stats.tolist()
ans = ["%.2g" % x for x in ans]
ans = "[" + " ".join(ans) + "]"
if stats_type == "value":
# This norm is useful because it is strictly less than the largest
# sqrt(eigenvalue) of the variance, which we print out, and shows,
# speaking in an approximate way, how much of that largest eigenvalue
# can be attributed to the mean of the distribution.
norm = (stats ** 2).sum().sqrt().item()
ans += f", norm={norm:.2g}"
mean = stats.mean().item()
rms = (stats ** 2).mean().sqrt().item()
ans += f", mean={mean:.2g}, rms={rms:.2g}"
# OK, "ans" contains the actual stats, e.g.
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
sizes = [x.tensor.shape[0] for x in stats_list]
size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
ndim = self.saved_tensors[0].ndim
tensors = [x.to(device) for x in self.saved_tensors]
for dim in range(ndim):
print_diagnostics_for_dim(self.name, dim, tensors, self.opts)
class ModelDiagnostic(object): class ModelDiagnostic(object):