Merge remote-tracking branch 'k2-fsa/master' into emformer_conv_simplify_new

This commit is contained in:
yaozengwei 2022-07-05 15:47:41 +08:00
commit dbea9a9970
183 changed files with 23792 additions and 721 deletions

View File

@ -0,0 +1,86 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/aishell/ASR
git lfs install
fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
log "Downloading pre-commputed fbank from $fbank_url"
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
ln -s $PWD/aishell-test-dev-manifests/data .
log "Downloading pre-trained model from $repo_url"
repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-epoch-29-avg-5-torch-1.10.pt pretrained.pt
popd
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless3/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$rep/test_wavs/BAC009S0764W0123.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./pruned_transducer_stateless3/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$rep/test_wavs/BAC009S0764W0123.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless3/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_char data/
ls -lh data
ls -lh pruned_transducer_stateless3/exp
log "Decoding test and dev"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless3/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless3/exp
done
rm pruned_transducer_stateless3/exp/*.pt
fi

View File

@ -32,6 +32,12 @@ for sym in 1 2 3; do
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -0,0 +1,86 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-epoch-24-avg-10.pt pretrained.pt
popd
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless2/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--simulate-streaming 1 \
--causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./pruned_transducer_stateless2/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--simulate-streaming 1 \
--causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained-epoch-24-avg-10.pt pruned_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless2/exp \
--simulate-streaming 1 \
--causal-convolution 1
done
rm pruned_transducer_stateless2/exp/*.pt
fi

View File

@ -0,0 +1,119 @@
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-aishell-2022-06-20
# pruned RNN-T + reworked model with random combiner
# https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_aishell_2022_06_20:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
- name: Display decoding results for aishell pruned_transducer_stateless3
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/aishell/ASR/
tree ./pruned_transducer_stateless3/exp
cd pruned_transducer_stateless3
echo "results for pruned_transducer_stateless3"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
- name: Upload decoding results for aishell pruned_transducer_stateless3
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20
path: egs/aishell/ASR/pruned_transducer_stateless3/exp/

View File

@ -0,0 +1,155 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-streaming-2022-06-26
# streaming conformer stateless transducer2
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_streaming_2022_06_26:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other-v2
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
- name: Display decoding results
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./pruned_transducer_stateless2/exp
cd pruned_transducer_stateless2
echo "results for pruned_transducer_stateless2"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified_beam_search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-06-26
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/

View File

@ -33,13 +33,13 @@ jobs:
# disable macOS test for now.
os: [ubuntu-18.04]
python-version: [3.7, 3.8]
torch: ["1.8.0", "1.10.0"]
torchaudio: ["0.8.0", "0.10.0"]
k2-version: ["1.9.dev20211101"]
torch: ["1.8.0", "1.11.0"]
torchaudio: ["0.8.0", "0.11.0"]
k2-version: ["1.15.1.dev20220427"]
exclude:
- torch: "1.8.0"
torchaudio: "0.10.0"
- torch: "1.10.0"
torchaudio: "0.11.0"
- torch: "1.11.0"
torchaudio: "0.8.0"
fail-fast: false
@ -67,7 +67,7 @@ jobs:
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
else
pip install torchaudio==${{ matrix.torchaudio }}

View File

@ -31,6 +31,8 @@ We provide the following recipes:
- [Aidatatang_200zh][aidatatang_200zh]
- [WenetSpeech][wenetspeech]
- [Alimeeting][alimeeting]
- [Aishell4][aishell4]
- [TAL_CSASR][tal_csasr]
### yesno
@ -270,6 +272,36 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
### Aishell4
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
The best CER(%) results:
| | test |
|----------------------|--------|
| greedy search | 29.89 |
| fast beam search | 28.91 |
| modified beam search | 29.08 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
### TAL_CSASR
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English):
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--|
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
|fast_beam_search| 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77|
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
## Deployment with C++
Once you have trained a model in icefall, you may want to deploy it with C++,
@ -298,6 +330,8 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR
@ -307,5 +341,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[aidatatang_200zh]: egs/aidatatang_200zh/ASR
[wenetspeech]: egs/wenetspeech/ASR
[alimeeting]: egs/alimeeting/ASR
[aishell4]: egs/aishell4/ASR
[tal_csasr]: egs/tal_csasr/ASR
[k2]: https://github.com/k2-fsa/k2
)

View File

@ -114,8 +114,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@ -155,6 +153,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@ -4,6 +4,8 @@
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
for how to run models in this recipe.
# Transducers
There are various folders containing the name `transducer` in this folder.
@ -14,6 +16,7 @@ The following table lists the differences among them.
| `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` |
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1,10 +1,93 @@
## Results
### Aishell training result(Transducer-stateless)
### Aishell training result(Stateless Transducer)
#### Pruned transducer stateless 3
See <https://github.com/k2-fsa/icefall/pull/436>
[./pruned_transducer_stateless3](./pruned_transducer_stateless3)
It uses pruned RNN-T.
| | test | dev | comment |
|------------------------|------|------|---------------------------------------|
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
Training command is:
```bash
./prepare.sh
./prepare_aidatatang_200zh.sh
export CUDA_VISIBLE_DEVICES="4,5,6,7"
./pruned_transducer_stateless3/train.py \
--exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \
--world-size 4 \
--max-duration 200 \
--datatang-prob 0.5 \
--start-epoch 1 \
--num-epochs 30 \
--use-fp16 1 \
--num-encoder-layers 12 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--context-size 1 \
--decoder-dim 512 \
--joiner-dim 512 \
--master-port 12356
```
**Caution**: It uses `--context-size=1`.
The tensorboard log is available at
<https://tensorboard.dev/experiment/OKKacljwR6ik7rbDr5gMqQ>
The decoding command is:
```bash
for epoch in 29; do
for avg in 5; do
for m in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless3/decode.py \
--exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--max-duration 600 \
--decoding-method $m \
--num-encoder-layers 12 \
--dim-feedforward 2048 \
--nhead 8 \
--context-size 1 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
done
done
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how
to use the pre-trained model for non-streaming ASR. See
<https://k2-fsa.github.io/sherpa/offline_asr/conformer/aishell.html>
#### 2022-03-01
[./transducer_stateless_modified-2](./transducer_stateless_modified-2)
It uses [optimized_transducer](https://github.com/csukuangfj/optimized_transducer)
for computing RNN-T loss.
Stateless transducer + modified transducer + using [aidatatang_200zh](http://www.openslr.org/62/) as extra training data.

View File

@ -364,7 +364,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)

View File

@ -364,7 +364,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)

View File

@ -18,7 +18,7 @@ stop_stage=10
# This directory contains the language model downloaded from
# https://huggingface.co/pkufool/aishell_lm
#
# - 3-gram.unpruned.apra
# - 3-gram.unpruned.arpa
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from

View File

@ -0,0 +1 @@
../transducer_stateless_modified-2/aidatatang_200zh.py

View File

@ -0,0 +1 @@
../transducer_stateless_modified-2/aishell.py

View File

@ -0,0 +1 @@
../transducer_stateless_modified-2/asr_datamodule.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py

View File

@ -0,0 +1,637 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from aishell import AIShell
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
token_table: k2.SymbolTable,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
token_table:
It maps token ID to a string.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
hyp_tokens = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_tokens.append(hyp)
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
token_table:
It maps a token ID to a string.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
token_table=token_table,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
asr_datamodule = AsrDataModule(args)
aishell = AIShell(manifest_dir=args.manifest_dir)
test_cuts = aishell.test_cuts()
dev_cuts = aishell.valid_cuts()
test_dl = asr_datamodule.test_dataloaders(test_cuts)
dev_dl = asr_datamodule.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_dls = [test_dl, dev_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
token_table=lexicon.token_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1

View File

@ -0,0 +1,277 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--jit 0 \
--epoch 29 \
--avg 5
It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt
To use the generated file with `pruned_transducer_stateless3/decode.py`,
you can do::
cd /path/to/exp_dir
ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt
cd /path/to/egs/aishell/ASR
./pruned_transducer_stateless3/decode.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=29,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=Path,
default=Path("pruned_transducer_stateless3/exp"),
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1,236 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
decoder_datatang: Optional[nn.Module] = None,
joiner_datatang: Optional[nn.Module] = None,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
encoder_dim:
Output dimension of the encoder network.
decoder_dim:
Output dimension of the decoder network.
joiner_dim:
Input dimension of the joiner network.
vocab_size:
Output dimension of the joiner network.
decoder_datatang:
Optional. The decoder network for the aidatatang_200zh dataset.
joiner_datatang:
Optional. The joiner network for the aidatatang_200zh dataset.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_datatang is not None:
self.simple_am_proj_datatang = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj_datatang = ScaledLinear(decoder_dim, vocab_size)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
aishell: bool = True,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
aishell:
True to use the decoder and joiner for the aishell dataset.
False to use the decoder and joiner for the aidatatang_200zh
dataset.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
assert torch.all(encoder_out_lens > 0)
if aishell:
decoder = self.decoder
simple_lm_proj = self.simple_lm_proj
simple_am_proj = self.simple_am_proj
joiner = self.joiner
else:
decoder = self.decoder_datatang
simple_lm_proj = self.simple_lm_proj_datatang
simple_am_proj = self.simple_am_proj_datatang
joiner = self.joiner_datatang
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=joiner.encoder_proj(encoder_out),
lm=joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,337 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
logging.info(f"Using {params.method}")
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp)
hyps = []
for hyp in hyp_list:
hyps.append([lexicon.token_table[i] for i in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -362,7 +362,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)

View File

@ -184,8 +184,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@ -225,6 +223,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
def main():
args = get_parser().parse_args()
assert args.jit is False, "torchscript support will be added later"
params = get_params()
params.update(vars(args))
@ -223,6 +221,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@ -405,7 +405,7 @@ def compute_loss(
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Compute RNN-T loss given the model and its inputs.
Args:
params:

View File

@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
def main():
args = get_parser().parse_args()
assert args.jit is False, "torchscript support will be added later"
params = get_params()
params.update(vars(args))
@ -223,6 +221,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@ -0,0 +1,19 @@
# Introduction
This recipe includes some different ASR models trained with Aishell4 (including S, M and L three subsets).
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.

117
egs/aishell4/ASR/RESULTS.md Normal file
View File

@ -0,0 +1,117 @@
## Results
### Aishell4 Char training results (Pruned Transducer Stateless5)
#### 2022-06-13
Using the codes from this PR https://github.com/k2-fsa/icefall/pull/399.
When use-averaged-model=False, the CERs are
| | test | comment |
|------------------------------------|------------|------------------------------------------|
| greedy search | 30.05 | --epoch 30, --avg 25, --max-duration 800 |
| modified beam search (beam size 4) | 29.16 | --epoch 30, --avg 25, --max-duration 800 |
| fast beam search (set as default) | 29.20 | --epoch 30, --avg 25, --max-duration 1500|
When use-averaged-model=True, the CERs are
| | test | comment |
|------------------------------------|------------|----------------------------------------------------------------------|
| greedy search | 29.89 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
| modified beam search (beam size 4) | 28.91 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
| fast beam search (set as default) | 29.08 | --iter 36000, --avg 8, --max-duration 1500 --use-averaged-model=True |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 220 \
--save-every-n 4000
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/tjaVRKERS8C10SzhpBcxSQ/#scalars
When use-averaged-model=False, the decoding command is:
```
epoch=30
avg=25
## greedy search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800
## modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4
## fast beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
When use-averaged-model=True, the decoding command is:
```
iter=36000
avg=8
## greedy search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--use-averaged-model True
## modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-averaged-model True
## fast beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--use-averaged-model True
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aishell4_pruned_transducer_stateless5>

View File

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the aidatatang_200zh dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell4(num_mel_bins: int = 80):
src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"train_S",
"train_M",
"train_L",
"test",
)
prefix = "aishell4"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=ChunkedLilcomHdf5Writer,
)
logging.info("About splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
min_duration=None,
)
cut_set.to_file(output_dir / cuts_filename)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,113 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""
from lhotse import load_manifest
def main():
paths = [
"./data/fbank/cuts_train_S.json.gz",
"./data/fbank/cuts_train_M.json.gz",
"./data/fbank/cuts_train_L.json.gz",
"./data/fbank/cuts_test.json.gz",
]
for path in paths:
print(f"Starting display the statistics for {path}")
cuts = load_manifest(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
Starting display the statistics for ./data/fbank/cuts_train_S.json.gz
Cuts count: 91995
Total duration (hours): 95.8
Speech duration (hours): 95.8 (100.0%)
***
Duration statistics (seconds):
mean 3.7
std 7.1
min 0.1
25% 0.9
50% 2.5
75% 5.4
99% 15.3
99.5% 17.5
99.9% 23.3
max 1021.7
Starting display the statistics for ./data/fbank/cuts_train_M.json.gz
Cuts count: 177195
Total duration (hours): 179.5
Speech duration (hours): 179.5 (100.0%)
***
Duration statistics (seconds):
mean 3.6
std 6.4
min 0.0
25% 0.9
50% 2.4
75% 5.2
99% 14.9
99.5% 17.0
99.9% 23.5
max 990.4
Starting display the statistics for ./data/fbank/cuts_train_L.json.gz
Cuts count: 37572
Total duration (hours): 49.1
Speech duration (hours): 49.1 (100.0%)
***
Duration statistics (seconds):
mean 4.7
std 4.0
min 0.2
25% 1.6
50% 3.7
75% 6.7
99% 17.5
99.5% 19.8
99.9% 26.2
max 87.4
Starting display the statistics for ./data/fbank/cuts_test.json.gz
Cuts count: 10574
Total duration (hours): 12.1
Speech duration (hours): 12.1 (100.0%)
***
Duration statistics (seconds):
mean 4.1
std 3.4
min 0.2
25% 1.4
50% 3.2
75% 5.8
99% 14.4
99.5% 14.9
99.9% 16.5
max 17.9
"""

View File

@ -0,0 +1,248 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/text,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import re
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
Token symbol table that mapping token to token ids.
words:
A list of strings representing words.
Returns:
Return a dict whose keys are words and values are the corresponding
tokens.
"""
lexicon = []
for word in words:
chars = list(word.strip(" \t"))
if contain_oov(token_sym_table, chars):
continue
lexicon.append((word, chars))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(text_file: str) -> Dict[str, int]:
"""Generate tokens from the given text file.
Args:
text_file:
A file that contains text lines to generate tokens.
Returns:
Return a dict whose keys are tokens and values are token ids ranged
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
whitespace = re.compile(r"([ \t\r\n]+)")
with open(text_file, "r", encoding="utf-8") as f:
for line in f:
line = re.sub(whitespace, "", line)
chars = list(line)
for char in chars:
if char not in tokens:
tokens[char] = len(tokens)
return tokens
def main():
lang_dir = Path("data/lang_char")
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(text_file)
lexicon = generate_lexicon(token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,390 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir", type=str, help="The lang dir, data/lang_phone"
)
return parser.parse_args()
def main():
out_dir = Path(get_args().lang_dir)
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / "tokens.txt", token2id)
write_mapping(out_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input words.txt without ids:
- words_no_ids.txt
and generates the new words.txt with related ids.
- words.txt
"""
import argparse
import logging
from tqdm import tqdm
def get_parser():
parser = argparse.ArgumentParser(
description="Prepare words.txt",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input-file",
default="data/lang_char/words_no_ids.txt",
type=str,
help="the words file without ids for WenetSpeech",
)
parser.add_argument(
"--output-file",
default="data/lang_char/words.txt",
type=str,
help="the words file with ids for WenetSpeech",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
f = open(input_file, "r", encoding="utf-8")
lines = f.readlines()
new_lines = []
add_words = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
new_lines.extend(add_words)
logging.info("Starting reading the input file")
for i in tqdm(range(len(lines))):
x = lines[i]
idx = 4 + i
new_line = str(x.strip("\n")) + " " + str(idx)
new_lines.append(new_line)
logging.info("Starting writing the words.txt")
f_out = open(output_file, "w", encoding="utf-8")
for line in new_lines:
f_out.write(line)
f_out.write("\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,106 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import os
import tempfile
import k2
from prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,
get_words,
lexicon_to_fst,
read_lexicon,
write_lexicon,
write_mapping,
)
def generate_lexicon_file() -> str:
fd, filename = tempfile.mkstemp()
os.close(fd)
s = """
!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
""".strip()
with open(filename, "w") as f:
f.write(s)
return filename
def test_read_lexicon(filename: str):
lexicon = read_lexicon(filename)
phones = get_phones(lexicon)
words = get_words(lexicon)
print(lexicon)
print(phones)
print(words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
print(lexicon_disambig)
print("max disambig:", f"#{max_disambig}")
phones = ["<eps>", "SIL", "SPN"] + phones
for i in range(max_disambig + 1):
phones.append(f"#{i}")
words = ["<eps>"] + words
phone2id = generate_id_map(phones)
word2id = generate_id_map(words)
print(phone2id)
print(word2id)
write_mapping("phones.txt", phone2id)
write_mapping("words.txt", word2id)
write_lexicon("a.txt", lexicon)
write_lexicon("a_disambig.txt", lexicon_disambig)
fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id)
fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
def main():
filename = generate_lexicon_file()
test_read_lexicon(filename)
os.remove(filename)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input "text", which refers to the transcript file for
WenetSpeech:
- text
and generates the output file text_word_segmentation which is implemented
with word segmenting:
- text_words_segmentation
"""
import argparse
import jieba
from tqdm import tqdm
jieba.enable_paddle()
def get_parser():
parser = argparse.ArgumentParser(
description="Chinese Word Segmentation for text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input-file",
default="data/lang_char/text",
type=str,
help="the input text file for WenetSpeech",
)
parser.add_argument(
"--output-file",
default="data/lang_char/text_words_segmentation",
type=str,
help="the text implemented with words segmenting for WenetSpeech",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
f = open(input_file, "r", encoding="utf-8")
lines = f.readlines()
new_lines = []
for i in tqdm(range(len(lines))):
x = lines[i].rstrip()
seg_list = jieba.cut(x, use_paddle=True)
new_line = " ".join(seg_list)
new_lines.append(new_line)
f_new = open(output_file, "w", encoding="utf-8")
for line in new_lines:
f_new.write(line)
f_new.write("\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,195 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe)
# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import codecs
import re
import sys
from typing import List
from pypinyin import lazy_pinyin, pinyin
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--nchar",
"-n",
default=1,
type=int,
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument(
"--space", default="<space>", type=str, help="space symbol"
)
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument(
"text", type=str, default=False, nargs="?", help="input text"
)
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "pinyin", "lazy_pinyin"],
help="""Transcript type. char/pinyin/lazy_pinyin""",
)
return parser
def token2id(
texts, token_table, token_type: str = "lazy_pinyin", oov: str = "<unk>"
) -> List[List[int]]:
"""Convert token to id.
Args:
texts:
The input texts, it refers to the chinese text here.
token_table:
The token table is built based on "data/lang_xxx/token.txt"
token_type:
The type of token, such as "pinyin" and "lazy_pinyin".
oov:
Out of vocabulary token. When a word(token) in the transcript
does not exist in the token list, it is replaced with `oov`.
Returns:
The list of ids for the input texts.
"""
if texts is None:
raise ValueError("texts can't be None!")
else:
oov_id = token_table[oov]
ids: List[List[int]] = []
for text in texts:
chars_list = list(str(text))
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
token_table[txt] if txt in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
text = pinyin(chars_list)
sub_ids = [
token_table[txt[0]] if txt[0] in token_table else oov_id
for txt in text
]
ids.append(sub_ids)
return ids
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
line = f.readline()
n = args.nchar
while line:
x = line.split()
print(" ".join(x[: args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols :]) # noqa E203
# get all matched positions
match_pos = []
for r in rs:
i = 0
while i >= 0:
m = r.search(a, i)
if m:
match_pos.append([m.start(), m.end()])
i = m.end()
else:
break
if len(match_pos) > 0:
chars = []
i = 0
while i < len(a):
start_pos, end_pos = exist_or_not(i, match_pos)
if start_pos is not None:
chars.append(a[start_pos:end_pos])
i = end_pos
else:
chars.append(a[i])
i += 1
a = chars
if args.trans_type == "pinyin":
a = pinyin(list(str(a)))
a = [one[0] for one in a]
if args.trans_type == "lazy_pinyin":
a = lazy_pinyin(list(str(a)))
a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203
a_flat = []
for z in a:
a_flat.append("".join(z))
a_chars = "".join(a_flat)
print(a_chars)
line = f.readline()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,119 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input "text_full", which includes three transcript files
(train_S, train_M and train_L) for AISHELL4:
- text_full
and generates the output file text_normalize which is implemented
to normalize text:
- text
"""
import argparse
from tqdm import tqdm
def get_parser():
parser = argparse.ArgumentParser(
description="Normalizing for text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input",
default="data/lang_char/text_full",
type=str,
help="the input text files for AISHELL4",
)
parser.add_argument(
"--output",
default="data/lang_char/text",
type=str,
help="the text implemented with normalizer for AISHELL4",
)
return parser
def text_normalize(str_line: str):
line = str_line.strip().rstrip("\n")
line = line.replace(" ", "")
line = line.replace("<sil>", "")
line = line.replace("<%>", "")
line = line.replace("<->", "")
line = line.replace("<$>", "")
line = line.replace("<#>", "")
line = line.replace("<_>", "")
line = line.replace("<space>", "")
line = line.replace("`", "")
line = line.replace("&", "")
line = line.replace(",", "")
line = line.replace("", "")
line = line.replace("", "A")
line = line.replace("", "B")
line = line.replace("", "C")
line = line.replace("", "K")
line = line.replace("", "T")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("", "")
line = line.replace("·", "")
line = line.replace("*", "")
line = line.replace("", "")
line = line.replace("$", "")
line = line.replace("+", "")
line = line.replace("-", "")
line = line.replace("\\", "")
line = line.replace("?", "")
line = line.replace("", "")
line = line.replace("%", "")
line = line.replace(".", "")
line = line.replace("<", "")
line = line.replace("", "")
line = line.upper()
return line
def main():
parser = get_parser()
args = parser.parse_args()
input_file = args.input
output_file = args.output
f = open(input_file, "r", encoding="utf-8")
lines = f.readlines()
new_lines = []
for i in tqdm(range(len(lines))):
new_line = text_normalize(lines[i])
new_lines.append(new_line)
f_new = open(output_file, "w", encoding="utf-8")
for line in new_lines:
f_new.write(line)
f_new.write("\n")
if __name__ == "__main__":
main()

160
egs/aishell4/ASR/prepare.sh Executable file
View File

@ -0,0 +1,160 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/aishell4
# You can find four directories:train_S, train_M, train_L and test.
# You can download it from https://openslr.org/111/
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/aishell4,
# you can create a symlink
#
# ln -sfv /path/to/aishell4 $dl_dir/aishell4
#
if [ ! -f $dl_dir/aishell4/train_L ]; then
lhotse download aishell4 $dl_dir/aishell4
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/musan
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare aishell4 manifest"
# We assume that you have downloaded the aishell4 corpus
# to $dl_dir/aishell4
if [ ! -f data/manifests/aishell4/.manifests.done ]; then
mkdir -p data/manifests/aishell4
lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4
touch data/manifests/aishell4/.manifests.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank/aishell4
lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4
touch data/fbank/aishell4/.fbank.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
if [ ! -f data/manifests/.musan_manifests.done ]; then
log "It may take 6 minutes"
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_manifests.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then
mkdir -p data/fbank
./local/compute_fbank_musan.py
touch data/fbank/.msuan.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for aishell4"
if [ ! -f data/fbank/.aishell4.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py
touch data/fbank/.aishell4.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir
# Prepare text.
# Note: in Linux, you can install jq with the following command:
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text_S
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text_M
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text_L
for r in text_S text_M text_L ; do
cat $lang_char_dir/$r >> $lang_char_dir/text_full
done
# Prepare text normalize
python ./local/text_normalize.py \
--input $lang_char_dir/text_full \
--output $lang_char_dir/text
# Prepare words segments
python ./local/text2segments.py \
--input $lang_char_dir/text \
--output $lang_char_dir/text_words_segmentation
cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
| sort -u | sed "/^$/d" \
| uniq > $lang_char_dir/words_no_ids.txt
# Prepare words.txt
if [ ! -f $lang_char_dir/words.txt ]; then
./local/prepare_words.py \
--input-file $lang_char_dir/words_no_ids.txt \
--output-file $lang_char_dir/words.txt
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py
fi
fi

View File

@ -0,0 +1,448 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class Aishell4AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=300,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=30000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_dl.sampler.load_state_dict(sampler_state_dict)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
rank=0,
world_size=1,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
rank=0,
world_size=1,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_S_cuts(self) -> CutSet:
logging.info("About to get S train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz"
)
@lru_cache()
def train_M_cuts(self) -> CutSet:
logging.info("About to get M train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz"
)
@lru_cache()
def train_L_cuts(self) -> CutSet:
logging.info("About to get L train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
# Aishell4 doesn't have dev data, here use test to replace dev.
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,630 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When use-averaged-model=True, usage:
(1) greedy search
./pruned_transducer_stateless5/decode.py \
--iter 36000 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 800 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./pruned_transducer_stateless5/decode.py \
--iter 36000 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-averaged-model True
(3) fast beam search
./pruned_transducer_stateless5/decode.py \
--iter 36000 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 800 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--use-averaged-model True
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import Aishell4AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from local.text_normalize import text_normalize
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
Aishell4AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
c.supervisions[0].text = text_normalize(text)
return c
aishell4 = Aishell4AsrDataModule(args)
test_cuts = aishell4.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)
test_dl = aishell4.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dl = [test_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,278 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless5/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/aishell4/ASR
./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
model.to("cpu")
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,358 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When use-averaged-model=True, usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--decoding-method greedy_search \
--use-averaged-model True \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--use-averaged-model True \
--decoding-method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search (not suggest)
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--use-averaged-model True \
--decoding-method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--lang-dir data/lang_char \
--use-averaged-model True \
--decoding-method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
./pruned_transducer_stateless5/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Path to lang.
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--decoding-method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.decoding_method}"
if params.decoding_method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding-method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/aishell4/ASR
python ./pruned_transducer_stateless5/test_model.py
"""
from train import get_params, get_transducer_model
def test_model_1():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 24
params.dim_feedforward = 1536 # 384 * 4
params.encoder_dim = 384
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
def test_model_M():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = 18
params.dim_feedforward = 1024
params.encoder_dim = 256
params.nhead = 4
params.decoder_dim = 512
params.joiner_dim = 512
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
def main():
# test_model_1()
test_model_M()
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

1
egs/aishell4/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../egs/aishell/ASR/shared

View File

@ -114,8 +114,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@ -155,6 +153,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)

View File

@ -131,8 +131,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@ -191,6 +189,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@ -1,5 +1,312 @@
## Results
### LibriSpeech BPE training results (Pruned Stateless Streaming Conformer RNN-T)
#### [pruned_transducer_stateless](./pruned_transducer_stateless)
See <https://github.com/k2-fsa/icefall/pull/380> for more details.
##### Training on full librispeech
The WERs are (the number in the table formatted as test-clean & test-other):
We only trained 25 epochs for saving time, if you want to get better results you can train more epochs.
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|----------------------|--------------|----------------|----------------|----------------|----------------|
| greedy search | 32 | 4.74 & 11.38 | 4.57 & 10.86 | 4.18 & 10.37 | 3.87 & 9.85 |
| greedy search | 64 | 4.74 & 11.25 | 4.48 & 10.72 | 4.1 & 10.24 | 3.85 & 9.73 |
| fast beam search | 32 | 4.75 & 11.1 | 4.48 & 10.65 | 4.12 & 10.18 | 3.95 & 9.67 |
| fast beam search | 64 | 4.7 & 11 | 4.37 & 10.49 | 4.07 & 10.04 | 3.89 & 9.53 |
| modified beam search | 32 | 4.64 & 10.94 | 4.38 & 10.51 | 4.11 & 10.14 | 3.87 & 9.61 |
| modified beam search | 64 | 4.59 & 10.81 | 4.29 & 10.39 | 4.02 & 10.02 | 3.84 & 9.43 |
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
The training command is:
```bash
./pruned_transducer_stateless/train.py \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 20 \
--num-left-chunks 4 \
--max-duration 300 \
--world-size 4 \
--start-epoch 0 \
--num-epochs 25
```
You can find the tensorboard log here <https://tensorboard.dev/experiment/ofxRakE6R7WHB1AoB8Bweg/>
The decoding command is:
```bash
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
for chunk in 2 4 8 16; do
for left in 32 64; do
./pruned_transducer_stateless/decode.py \
--simulate-streaming 1 \
--decode-chunk-size ${chunk} \
--left-context ${left} \
--causal-convolution 1 \
--epoch 24 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-sym-per-frame 1 \
--max-duration 1000 \
--decoding-method ${decoding_method}
done
done
```
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless_20220625>
#### [pruned_transducer_stateless2](./pruned_transducer_stateless2)
See <https://github.com/k2-fsa/icefall/pull/380> for more details.
##### Training on full librispeech
The WERs are (the number in the table formatted as test-clean & test-other):
We only trained 25 epochs for saving time, if you want to get better results you can train more epochs.
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|----------------------|--------------|----------------|----------------|----------------|----------------|
| greedy search | 32 | 4.2 & 10.64 | 3.97 & 10.03 | 3.83 & 9.58 | 3.7 & 9.11 |
| greedy search | 64 | 4.16 & 10.5 | 3.93 & 9.99 | 3.73 & 9.45 | 3.63 & 9.04 |
| fast beam search | 32 | 4.13 & 10.3 | 3.93 & 9.82 | 3.8 & 9.35 | 3.62 & 8.93 |
| fast beam search | 64 | 4.13 & 10.22 | 3.89 & 9.68 | 3.73 & 9.27 | 3.52 & 8.82 |
| modified beam search | 32 | 4.02 & 10.22 | 3.9 & 9.71 | 3.74 & 9.33 | 3.59 & 8.87 |
| modified beam search | 64 | 4.05 & 10.08 | 3.81 & 9.67 | 3.68 & 9.21 | 3.56 & 8.77 |
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless2/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
The training command is:
```bash
./pruned_transducer_stateless2/train.py \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 20 \
--num-left-chunks 4 \
--max-duration 300 \
--world-size 4 \
--start-epoch 0 \
--num-epochs 25
```
You can find the tensorboard log here <https://tensorboard.dev/experiment/hbltNS5TQ1Kiw0D1vcoakw/>
The decoding command is:
```bash
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
for chunk in 2 4 8 16; do
for left in 32 64; do
./pruned_transducer_stateless2/decode.py \
--simulate-streaming 1 \
--decode-chunk-size ${chunk} \
--left-context ${left} \
--causal-convolution 1 \
--epoch 24 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-sym-per-frame 1 \
--max-duration 1000 \
--decoding-method ${decoding_method}
done
done
```
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625>
#### [pruned_transducer_stateless3](./pruned_transducer_stateless3)
See <https://github.com/k2-fsa/icefall/pull/380> for more details.
##### Training on full librispeech (**Use giga_prob = 0.5**)
The WERs are (the number in the table formatted as test-clean & test-other):
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|----------------------|--------------|----------------|----------------|----------------|----------------|
| greedy search | 32 | 3.7 & 9.53 | 3.45 & 8.88 | 3.28 & 8.45 | 3.13 & 7.93 |
| greedy search | 64 | 3.69 & 9.36 | 3.39 & 8.68 | 3.28 & 8.19 | 3.08 & 7.83 |
| fast beam search | 32 | 3.71 & 9.18 | 3.36 & 8.65 | 3.23 & 8.23 | 3.17 & 7.78 |
| fast beam search | 64 | 3.61 & 9.03 | 3.46 & 8.43 | 3.2 & 8.0 | 3.11 & 7.63 |
| modified beam search | 32 | 3.56 & 9.08 | 3.34 & 8.58 | 3.21 & 8.14 | 3.06 & 7.73 |
| modified beam search | 64 | 3.55 & 8.86 | 3.29 & 8.34 | 3.16 & 8.01 | 3.05 & 7.57 |
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless3/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
The training command is (Note: this model was trained with mix-precision training):
```bash
./pruned_transducer_stateless3/train.py \
--exp-dir pruned_transducer_stateless3/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 32 \
--num-left-chunks 4 \
--max-duration 300 \
--world-size 4 \
--use-fp16 1 \
--start-epoch 0 \
--num-epochs 37 \
--num-workers 2 \
--giga-prob 0.5
```
You can find the tensorboard log here <https://tensorboard.dev/experiment/vL7dWVZqTYaSeoOED4rtow/>
The decoding command is:
```bash
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
for chunk in 2 4 8 16; do
for left in 32 64; do
./pruned_transducer_stateless3/decode.py \
--simulate-streaming 1 \
--decode-chunk-size ${chunk} \
--left-context ${left} \
--causal-convolution 1 \
--epoch 36 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-sym-per-frame 1 \
--max-duration 1000 \
--decoding-method ${decoding_method}
done
done
```
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless3_giga_0.5_20220625>
##### Training on full librispeech (**Use giga_prob = 0.9**)
The WERs are (the number in the table formatted as test-clean & test-other):
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|----------------------|--------------|----------------|----------------|----------------|----------------|
| greedy search | 32 | 3.25 & 8.2 | 3.07 & 7.67 | 2.91 & 7.28 | 2.8 & 6.89 |
| greedy search | 64 | 3.22 & 8.12 | 3.05 & 7.59 | 2.91 & 7.07 | 2.78 & 6.81 |
| fast beam search | 32 | 3.26 & 8.2 | 3.06 & 7.56 | 2.98 & 7.08 | 2.77 & 6.75 |
| fast beam search | 64 | 3.24 & 8.09 | 3.06 & 7.43 | 2.88 & 7.03 | 2.73 & 6.68 |
| modified beam search | 32 | 3.13 & 7.91 | 2.99 & 7.45 | 2.83 & 6.98 | 2.68 & 6.75 |
| modified beam search | 64 | 3.08 & 7.8 | 2.97 & 7.37 | 2.81 & 6.82 | 2.66 & 6.67 |
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless3/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
The training command is:
```bash
./pruned_transducer_stateless3/train.py \
--exp-dir pruned_transducer_stateless3/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 8 \
--max-duration 300 \
--world-size 8 \
--start-epoch 0 \
--num-epochs 26 \
--num-workers 2 \
--giga-prob 0.9
```
You can find the tensorboard log here <https://tensorboard.dev/experiment/WBGBDzt7SByRnvCBEfQpGQ/>
The decoding command is:
```bash
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
for chunk in 2 4 8 16; do
for left in 32 64; do
./pruned_transducer_stateless3/decode.py \
--simulate-streaming 1 \
--decode-chunk-size ${chunk} \
--left-context ${left} \
--causal-convolution 1 \
--epoch 25 \
--avg 12 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-sym-per-frame 1 \
--max-duration 1000 \
--decoding-method ${decoding_method}
done
done
```
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless3_giga_0.9_20220625>
#### [pruned_transducer_stateless4](./pruned_transducer_stateless4)
See <https://github.com/k2-fsa/icefall/pull/380> for more details.
##### Training on full librispeech
The WERs are (the number in the table formatted as test-clean & test-other):
We only trained 25 epochs for saving time, if you want to get better results you can train more epochs.
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|----------------------|--------------|----------------|----------------|----------------|----------------|
| greedy search | 32 | 3.96 & 10.45 | 3.73 & 9.97 | 3.54 & 9.56 | 3.45 & 9.08 |
| greedy search | 64 | 3.9 & 10.34 | 3.7 & 9.9 | 3.53 & 9.41 | 3.39 & 9.03 |
| fast beam search | 32 | 3.9 & 10.09 | 3.69 & 9.65 | 3.58 & 9.28 | 3.46 & 8.91 |
| fast beam search | 64 | 3.82 & 10.03 | 3.67 & 9.56 | 3.51 & 9.18 | 3.43 & 8.78 |
| modified beam search | 32 | 3.78 & 10.0 | 3.63 & 9.54 | 3.43 & 9.29 | 3.39 & 8.84 |
| modified beam search | 64 | 3.76 & 9.95 | 3.54 & 9.48 | 3.4 & 9.13 | 3.33 & 8.74 |
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless4/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
The training command is:
```bash
./pruned_transducer_stateless4/train.py \
--exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 20 \
--num-left-chunks 4 \
--max-duration 300 \
--world-size 4 \
--start-epoch 1 \
--num-epochs 25
```
You can find the tensorboard log here <https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/>
The decoding command is:
```bash
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
for chunk in 2 4 8 16; do
for left in 32 64; do
./pruned_transducer_stateless4/decode.py \
--simulate-streaming 1 \
--decode-chunk-size ${chunk} \
--left-context ${left} \
--causal-convolution 1 \
--epoch 25 \
--avg 3 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-sym-per-frame 1 \
--max-duration 1000 \
--decoding-method ${decoding_method}
done
done
```
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless4_20220625>
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
[conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless)
@ -781,9 +1088,25 @@ The WERs are:
The train and decode commands are:
`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp --world-size 8 --num-epochs 26 --full-libri 1 --max-duration 300`
```bash
python3 ./pruned_transducer_stateless2/train.py \
--exp-dir=pruned_transducer_stateless2/exp \
--world-size 8 \
--num-epochs 26 \
--full-libri 1 \
--max-duration 300
```
and:
`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp --epoch 25 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600`
```bash
python3 ./pruned_transducer_stateless2/decode.py \
--exp-dir pruned_transducer_stateless2/exp \
--epoch 25 \
--avg 8 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 600
```
The Tensorboard log is at <https://tensorboard.dev/experiment/Xoz0oABMTWewo1slNFXkyA> (apologies, log starts
only from epoch 3).
@ -796,9 +1119,26 @@ can be found at
#### Training on train-clean-100:
Trained with 1 job:
`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_ws1 --world-size 1 --num-epochs 40 --full-libri 0 --max-duration 300`
```
python3 ./pruned_transducer_stateless2/train.py \
--exp-dir=pruned_transducer_stateless2/exp_100h_ws1 \
--world-size 1 \
--num-epochs 40 \
--full-libri 0 \
--max-duration 300
```
and decoded with:
`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp_100h_ws1 --epoch 19 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600`.
```
python3 ./pruned_transducer_stateless2/decode.py \
--exp-dir pruned_transducer_stateless2/exp_100h_ws1 \
--epoch 19 \
--avg 8 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 600
```
The Tensorboard log is at <https://tensorboard.dev/experiment/AhnhooUBRPqTnaggoqo7lg> (learning rate
schedule is not visible due to a since-fixed bug).
@ -812,9 +1152,26 @@ schedule is not visible due to a since-fixed bug).
| fast beam search | 6.53 | 16.82 | --epoch 39 --avg 10 --decoding-method fast_beam_search |
Trained with 2 jobs:
`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_ws2 --world-size 2 --num-epochs 40 --full-libri 0 --max-duration 300`
```bash
python3 ./pruned_transducer_stateless2/train.py \
--exp-dir=pruned_transducer_stateless2/exp_100h_ws2 \
--world-size 2 \
--num-epochs 40 \
--full-libri 0 \
--max-duration 300
```
and decoded with:
`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp_100h_ws2 --epoch 19 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600`.
```
python3 ./pruned_transducer_stateless2/decode.py \
--exp-dir pruned_transducer_stateless2/exp_100h_ws2 \
--epoch 19 \
--avg 8 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 600
```
The Tensorboard log is at <https://tensorboard.dev/experiment/dvOC9wsrSdWrAIdsebJILg/>
(learning rate schedule is not visible due to a since-fixed bug).
@ -827,9 +1184,26 @@ The Tensorboard log is at <https://tensorboard.dev/experiment/dvOC9wsrSdWrAIdseb
Trained with 4 jobs:
`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_ws4 --world-size 4 --num-epochs 40 --full-libri 0 --max-duration 300`
```
python3 ./pruned_transducer_stateless2/train.py \
--exp-dir=pruned_transducer_stateless2/exp_100h_ws4 \
--world-size 4 \
--num-epochs 40 \
--full-libri 0 \
--max-duration 300
```
and decoded with:
`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp_100h_ws4 --epoch 19 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600`.
```
python3 ./pruned_transducer_stateless2/decode.py \
--exp-dir pruned_transducer_stateless2/exp_100h_ws4 \
--epoch 19 \
--avg 8 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 600
```
The Tensorboard log is at <https://tensorboard.dev/experiment/a3T0TyC0R5aLj5bmFbRErA/>
@ -846,7 +1220,16 @@ The Tensorboard log is at <https://tensorboard.dev/experiment/a3T0TyC0R5aLj5bmFb
Trained with 1 job, with --use-fp16=True --max-duration=300 i.e. with half-precision
floats (but without increasing max-duration), after merging <https://github.com/k2-fsa/icefall/pull/305>.
Train command was
`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_fp16 --world-size 1 --num-epochs 40 --full-libri 0 --max-duration 300 --use-fp16 True`
```
python3 ./pruned_transducer_stateless2/train.py \
--exp-dir=pruned_transducer_stateless2/exp_100h_fp16 \
--world-size 1 \
--num-epochs 40 \
--full-libri 0 \
--max-duration 300 \
--use-fp16 True
```
The Tensorboard log is at <https://tensorboard.dev/experiment/DAtGG9lpQJCROUDwPNxwpA>
@ -860,7 +1243,16 @@ The Tensorboard log is at <https://tensorboard.dev/experiment/DAtGG9lpQJCROUDwPN
Trained with 1 job, with --use-fp16=True --max-duration=500, i.e. with half-precision
floats and max-duration increased from 300 to 500, after merging <https://github.com/k2-fsa/icefall/pull/305>.
Train command was
`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_fp16 --world-size 1 --num-epochs 40 --full-libri 0 --max-duration 500 --use-fp16 True`
```
python3 ./pruned_transducer_stateless2/train.py \
--exp-dir=pruned_transducer_stateless2/exp_100h_fp16 \
--world-size 1 \
--num-epochs 40 \
--full-libri 0 \
--max-duration 500 \
--use-fp16 True
```
The Tensorboard log is at <https://tensorboard.dev/experiment/Km7QBHYnSLWs4qQnAJWsaA>
@ -872,7 +1264,6 @@ The Tensorboard log is at <https://tensorboard.dev/experiment/Km7QBHYnSLWs4qQnAJ
### LibriSpeech BPE training results (Pruned Transducer)
Conformer encoder + non-current decoder. The decoder
@ -1299,17 +1690,18 @@ You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3x
#### 2021-11-09
The best WER, as of 2021-11-09, for the librispeech test dataset is below
(using HLG decoding + n-gram LM rescoring + attention decoder rescoring):
The best WER, as of 2022-06-20, for the librispeech test dataset is below
(using HLG decoding + n-gram LM rescoring + attention decoder rescoring + rnn lm rescoring):
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.42 | 5.73 |
| WER | 2.32 | 5.39 |
Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
| ngram_lm_scale | attention_scale |
|----------------|-----------------|
| 2.0 | 2.0 |
| ngram_lm_scale | attention_scale | rnn_lm_scale |
|----------------|-----------------|--------------|
| 0.3 | 2.1 | 2.2 |
To reproduce the above result, use the following commands for training:
@ -1330,11 +1722,27 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 0 \
--num-epochs 90
# Note: It trains for 90 epochs, but the best WER is at epoch-77.pt
# Train the RNN-LM
cd icefall
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./rnn_lm/train.py \
--exp-dir rnn_lm/exp_2048_3_tied \
--start-epoch 0 \
--world-size 4 \
--num-epochs 30 \
--use-fp16 1 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--batch-size 500 \
--tie-weights true
```
and the following command for decoding
```
rnn_dir=$(git rev-parse --show-toplevel)/icefall/rnn_lm
./conformer_ctc/decode.py \
--exp-dir conformer_ctc/exp_500_att0.8 \
--lang-dir data/lang_bpe_500 \
@ -1344,13 +1752,23 @@ and the following command for decoding
--num-paths 1000 \
--epoch 77 \
--avg 55 \
--method attention-decoder \
--nbest-scale 0.5
--nbest-scale 0.5 \
--rnn-lm-exp-dir ${rnn_dir}/exp_2048_3_tied \
--rnn-lm-epoch 29 \
--rnn-lm-avg 3 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights true \
--method rnn-lm
```
You can find the pre-trained model by visiting
You can find the Conformer-CTC pre-trained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09>
and the RNN-LM pre-trained model:
<https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
The tensorboard log for training is available at
<https://tensorboard.dev/experiment/hZDWrZfaSqOMqtW0NEfXKg/#scalars>

View File

@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)

View File

@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
@ -38,15 +38,19 @@ from icefall.decode import (
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -93,7 +97,9 @@ def get_parser():
is the decoding result.
- (5) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
- (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
you have trained an RNN LM using ./rnn_lm/train.py
- (7) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
@ -105,7 +111,7 @@ def get_parser():
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
""",
)
@ -116,7 +122,7 @@ def get_parser():
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
A smaller value results in more unique paths.
""",
)
@ -139,11 +145,67 @@ def get_parser():
"--lm-dir",
type=str,
default="data/lm",
help="""The LM dir.
help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
return parser
@ -173,6 +235,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@ -205,6 +268,8 @@ def decode_one_batch(
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@ -330,6 +395,7 @@ def decode_one_batch(
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@ -357,8 +423,6 @@ def decode_one_batch(
G_with_epsilon_loops=G,
lm_scale_list=None,
)
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
@ -370,6 +434,26 @@ def decode_one_batch(
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
elif params.method == "rnn-lm":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_rnn_lm(
lattice=rescored_lattice,
num_paths=params.num_paths,
rnn_lm_model=rnn_lm_model,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
blank_id=0,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -388,6 +472,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
rnn_lm_model: Optional[nn.Module],
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
@ -405,6 +490,8 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
rnn_lm_model:
The neural model for RNN LM.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
@ -442,6 +529,7 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
@ -490,7 +578,7 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
enable_log = False
else:
@ -566,6 +654,10 @@ def main():
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
params.num_classes = num_classes
params.sos_id = sos_id
params.eos_id = eos_id
if params.method == "ctc-decoding":
HLG = None
H = k2.ctc_topo(
@ -590,6 +682,7 @@ def main():
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
@ -621,7 +714,11 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
if params.method in [
"whole-lattice-rescoring",
"attention-decoder",
"rnn-lm",
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
@ -648,20 +745,40 @@ def main():
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
rnn_lm_model = None
if params.method == "rnn-lm":
rnn_lm_model = RnnLmModel(
vocab_size=params.num_classes,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
@ -678,6 +795,7 @@ def main():
dl=test_dl,
params=params,
model=model,
rnn_lm_model=rnn_lm_model,
HLG=HLG,
H=H,
bpe_model=bpe_model,

View File

@ -363,7 +363,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)

View File

@ -1018,6 +1018,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@ -1088,9 +1090,6 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
@ -1098,7 +1097,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
warmup=0.0,
warmup=warmup,
)
loss.backward()
optimizer.step()

119
egs/librispeech/ASR/distillation_with_hubert.sh Normal file → Executable file
View File

@ -1,3 +1,5 @@
#!/usr/bin/env bash
#
# A short introduction about distillation framework.
#
# A typical traditional distillation method is
@ -14,15 +16,15 @@
# teacher embeddings.
# 3. a middle layer 6(1-based) out of total 6 layers is used to extract
# student embeddings.
# This is an example to do distillation with librispeech clean-100 subset.
# run with command:
# bash distillation_with_hubert.sh [0|1|2|3|4]
#
# For example command
# bash distillation_with_hubert.sh 0
# will download hubert model.
stage=$1
# To directly download the extracted codebook indexes for model distillation, you can
# set stage=2, stop_stage=4, use_extracted_codebook=True
#
# To start from scratch, you can
# set stage=0, stop_stage=4, use_extracted_codebook=False
stage=0
stop_stage=4
# Set the GPUs available.
# This script requires at least one GPU.
@ -33,10 +35,35 @@ stage=$1
# export CUDA_VISIBLE_DEVICES="0"
#
# Suppose GPU 2,3,4,5 are available.
export CUDA_VISIBLE_DEVICES="2,3,4,5"
export CUDA_VISIBLE_DEVICES="0,1,2,3"
exp_dir=./pruned_transducer_stateless6/exp
mkdir -p $exp_dir
if [ $stage -eq 0 ]; then
# full_libri can be "True" or "False"
# "True" -> use full librispeech dataset for distillation
# "False" -> use train-clean-100 subset for distillation
full_libri=False
# use_extracted_codebook can be "True" or "False"
# "True" -> stage 0 and stage 1 would be skipped,
# and directly download the extracted codebook indexes for distillation
# "False" -> start from scratch
use_extracted_codebook=False
# teacher_model_id can be one of
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
# "hubert_xtralarge_ll60k" -> pretrained model without fintuing
teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 0: Download HuBERT model"
# Preparation stage.
# Install fairseq according to:
@ -45,7 +72,7 @@ if [ $stage -eq 0 ]; then
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
if [ $has_fairseq == 'False' ]; then
echo "Please install fairseq before running following stages"
log "Please install fairseq before running following stages"
exit 1
fi
@ -56,42 +83,41 @@ if [ $stage -eq 0 ]; then
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then
echo "Please install quantization before running following stages"
log "Please install quantization before running following stages"
exit 1
fi
echo "Download hubert model."
log "Download HuBERT model."
# Parameters about model.
exp_dir=./pruned_transducer_stateless6/exp/
model_id=hubert_xtralarge_ll60k_finetune_ls960
hubert_model_dir=${exp_dir}/hubert_models
hubert_model=${hubert_model_dir}/${model_id}.pt
hubert_model=${hubert_model_dir}/${teacher_model_id}.pt
mkdir -p ${hubert_model_dir}
# For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
if [ -f ${hubert_model} ]; then
echo "hubert model alread exists."
log "HuBERT model alread exists."
else
wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model}
wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir}
wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
fi
fi
if [ ! -d ./data/fbank ]; then
echo "This script assumes ./data/fbank is already generated by prepare.sh"
log "This script assumes ./data/fbank is already generated by prepare.sh"
exit 1
fi
if [ $stage -eq 1 ]; then
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 1: Verify that the downloaded HuBERT model is correct."
# This stage is not directly used by codebook indexes extraction.
# It is a method to "prove" that the downloaed hubert model
# is inferenced in an correct way if WERs look like normal.
# Expect WERs:
# [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
# [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
./pruned_transducer_stateless6/hubert_decode.py
./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir
fi
if [ $stage -eq 2 ]; then
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# Analysis of disk usage:
# With num_codebooks==8, each teacher embedding is quantized into
# a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
@ -113,25 +139,61 @@ if [ $stage -eq 2 ]; then
# During quantizer's training data(teacher embedding) and it's training,
# only the first ONE GPU is used.
# During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
if [ "$use_extracted_codebook" == "True" ]; then
if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then
log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960"
exit 1
fi
mkdir -p $exp_dir/vq
codebook_dir=$exp_dir/vq/$teacher_model_id
mkdir -p codebook_dir
codebook_download_dir=$exp_dir/download_codebook
if [ -d $codebook_download_dir ]; then
log "$codebook_download_dir exists, you should remove it first."
exit 1
fi
log "Downloading extracted codebook indexes to $codebook_download_dir"
# Make sure you have git-lfs installed (https://git-lfs.github.com)
git lfs install
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
mkdir -p data/vq_fbank
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
mkdir -p $codebook_dir/splits4
mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
log "Remove $codebook_download_dir"
rm -rf $codebook_download_dir
fi
./pruned_transducer_stateless6/extract_codebook_index.py \
--full-libri False
--full-libri $full_libri \
--exp-dir $exp_dir \
--embedding-layer 36 \
--num-utts 1000 \
--num-codebooks 8 \
--max-duration 100 \
--teacher-model-id $teacher_model_id \
--use-extracted-codebook $use_extracted_codebook
fi
if [ $stage -eq 3 ]; then
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# Example training script.
# Note: it's better to set spec-aug-time-warpi-factor=-1
WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \
--master-port 12359 \
--full-libri False \
--full-libri $full_libri \
--spec-aug-time-warp-factor -1 \
--max-duration 300 \
--world-size ${WORLD_SIZE} \
--num-epochs 20
--num-epochs 20 \
--exp-dir $exp_dir \
--enable-distillation True
fi
if [ $stage -eq 4 ]; then
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# Results should be similar to:
# errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
# errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60
@ -140,5 +202,6 @@ if [ $stage -eq 4 ]; then
--epoch 20 \
--avg 10 \
--max-duration 200 \
--exp-dir ./pruned_transducer_stateless6/exp
--exp-dir $exp_dir \
--enable-distillation True
fi

View File

@ -23,6 +23,7 @@ This file downloads the following LibriSpeech LM files:
- 4-gram.arpa.gz
- librispeech-vocab.txt
- librispeech-lexicon.txt
- librispeech-lm-norm.txt.gz
from http://www.openslr.org/resources/11
and save them in the user provided directory.
@ -61,6 +62,7 @@ def main(out_dir: str):
"4-gram.arpa.gz",
"librispeech-vocab.txt",
"librispeech-lexicon.txt",
"librispeech-lm-norm.txt.gz",
)
for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):

View File

@ -0,0 +1,172 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
# Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes a `bpe.model` and a text file such as
./download/lm/librispeech-lm-norm.txt
and outputs the LM training data to a supplied directory such
as data/lm_training_bpe_500. The format is as follows:
It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
representation of a dict with the following format:
'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32
containing the BPE representations of each word, indexed by
integer word ID. (These integer word IDS are present in
'lm_data'). The sentencepiece object can be used to turn the
words and BPE units into string form.
'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype
torch.int32 containing all the sentences, as word-ids (we don't
output the string form of this directly but it can be worked out
together with 'words' and the bpe.model).
'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing
number of BPE tokens of each sentence.
"""
import argparse
import logging
from pathlib import Path
import k2
import sentencepiece as spm
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
help="Input BPE model, e.g. data/bpe_500/bpe.model",
)
parser.add_argument(
"--lm-data",
type=str,
help="""Input LM training data as text, e.g.
download/pb.train.txt""",
)
parser.add_argument(
"--lm-archive",
type=str,
help="""Path to output archive, e.g. data/bpe_500/lm_data.pt;
look at the source of this script to see the format.""",
)
return parser.parse_args()
def main():
args = get_args()
if Path(args.lm_archive).exists():
logging.warning(f"{args.lm_archive} exists - skipping")
return
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
# word2index is a dictionary from words to integer ids. No need to reserve
# space for epsilon, etc.; the words are just used as a convenient way to
# compress the sequences of BPE pieces.
word2index = dict()
word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces.
sentences = [] # Will be a list-of-list-of-int, representing word-ids.
if "librispeech-lm-norm" in args.lm_data:
num_lines_in_total = 40418261.0
step = 5000000
elif "valid" in args.lm_data:
num_lines_in_total = 5567.0
step = 3000
elif "test" in args.lm_data:
num_lines_in_total = 5559.0
step = 3000
else:
num_lines_in_total = None
step = None
processed = 0
with open(args.lm_data) as f:
while True:
line = f.readline()
if line == "":
break
if step and processed % step == 0:
logging.info(
f"Processed number of lines: {processed} "
f"({processed/num_lines_in_total*100: .3f}%)"
)
processed += 1
line_words = line.split()
for w in line_words:
if w not in word2index:
w_bpe = sp.encode(w)
word2index[w] = len(word2bpe)
word2bpe.append(w_bpe)
sentences.append([word2index[w] for w in line_words])
logging.info("Constructing ragged tensors")
words = k2.ragged.RaggedTensor(word2bpe)
sentences = k2.ragged.RaggedTensor(sentences)
output = dict(words=words, sentences=sentences)
num_sentences = sentences.dim0
logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
if step and i % step == 0:
logging.info(
f"Processed number of lines: {i} "
f"({i/num_sentences*100: .3f}%)"
)
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
output["sentence_lengths"] = torch.tensor(
sentence_lengths, dtype=torch.int32
)
torch.save(output, args.lm_archive)
logging.info(f"Saved to {args.lm_archive}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../ptb/LM/local/sort_lm_training_data.py

View File

@ -38,7 +38,6 @@ def get_args():
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the training corpus: transcript_words.txt.
The generated bpe.model is saved to this directory.
""",
)

View File

@ -24,6 +24,7 @@ stop_stage=100
# - 4-gram.arpa
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
@ -40,9 +41,9 @@ dl_dir=$PWD/download
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
# 5000
# 2000
# 1000
5000
2000
1000
500
)
@ -278,3 +279,99 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Generate LM training data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
lang_dir=data/lang_bpe_${vocab_size}
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $dl_dir/lm/librispeech-lm-norm.txt \
--lm-archive $out_dir/lm_data.pt
done
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Generate LM validation data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/valid.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/valid.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/valid.txt \
--lm-archive $out_dir/lm_data-valid.pt
done
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Generate LM test data"
for vocab_size in ${vocab_sizes[@]}; do
log "Processing vocab_size == ${vocab_size}"
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
if [ ! -f $out_dir/test.txt ]; then
files=$(
find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $out_dir/test.txt
fi
lang_dir=data/lang_bpe_${vocab_size}
./local/prepare_lm_training_data.py \
--bpe-model $lang_dir/bpe.model \
--lm-data $out_dir/test.txt \
--lm-archive $out_dir/lm_data-test.pt
done
fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Sort LM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
# Sentence length equals to the number of BPE tokens
# in a sentence.
for vocab_size in ${vocab_sizes[@]}; do
out_dir=data/lm_training_bpe_${vocab_size}
mkdir -p $out_dir
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-valid.pt \
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
--out-statistics $out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data-test.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
done
fi

View File

@ -75,6 +75,202 @@ def fast_beam_search_one_best(
return hyps
def fast_beam_search_nbest_LG(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores,
log_semiring=True, # Note: we always use True
)
# See https://github.com/k2-fsa/icefall/pull/420 for why
# we always use log_semiring=True
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,

View File

@ -50,22 +50,58 @@ Usage:
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search using LG
(5) fast beam search (nbest)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--use-LG True \
--use-max False \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 8 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(6) decode in streaming mode (take greedy search as an example)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method greedy_search
"""
@ -82,12 +118,15 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -153,7 +192,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=str,
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
@ -167,6 +206,11 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
@ -182,30 +226,13 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
default=4,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--use-LG",
type=str2bool,
default=False,
help="""Whether to use an LG graph for FSA-based beam search.
Used only when --decoding_method is fast_beam_search. If setting true,
it assumes there is an LG.pt file in lang_dir.""",
)
parser.add_argument(
"--use-max",
type=str2bool,
default=False,
help="""If True, use max-op to select the hypothesis that have the
max log_prob in case of duplicate hypotheses.
If False, use log_add.
Used only for beam_search, modified_beam_search, and fast_beam_search
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
@ -214,7 +241,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search.
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -222,9 +249,10 @@ def get_parser():
parser.add_argument(
"--max-contexts",
type=int,
default=4,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -232,7 +260,8 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -250,6 +279,47 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
@ -286,7 +356,8 @@ def decode_one_batch(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -299,11 +370,21 @@ def decode_one_batch(
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -316,12 +397,51 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.use_LG:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
else:
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -339,7 +459,6 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
use_max=params.use_max,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -361,7 +480,6 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
use_max=params.use_max,
)
else:
raise ValueError(
@ -371,14 +489,17 @@ def decode_one_batch(
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -406,7 +527,8 @@ def decode_dataset(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -424,7 +546,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -517,6 +639,9 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -526,17 +651,23 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-use-LG-{params.use_LG}"
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-use-max-{params.use_max}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-use-max-{params.use_max}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -558,6 +689,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
@ -596,12 +732,14 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
if params.use_LG:
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:

View File

@ -0,0 +1,126 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import k2
import torch
from icefall.utils import AttributeDict
class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
initial_states: List[torch.Tensor],
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
initial_states:
Initial decode states of the model, e.g. the return value of
`get_init_state` in conformer.py
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Used only when decoding_method is fast_beam_search.
device:
The device to run this stream.
"""
if decoding_graph is not None:
assert device == decoding_graph.device
self.params = params
self.LOG_EPS = math.log(1e-10)
self.states = initial_states
# It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None
self.num_frames: int = 0
# how many frames have been processed. (before subsampling).
# we only modify this value in `func:get_feature_frames`.
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of current utterance.
self.ground_truth: str = ""
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
# how many frames have been processed, after subsampling (i.e. a
# cumulative sum of the second return value of
# encoder.streaming_forward
self.done_frames: int = 0
self.pad_length = (
params.right_context + 2
) * params.subsampling_factor + 3
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
)
else:
assert (
False
), f"Decoding method :{params.decoding_method} do not support."
@property
def done(self) -> bool:
"""Return True if all the features are processed."""
return self._done
def set_features(
self,
features: torch.Tensor,
) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.features = torch.nn.functional.pad(
features,
(0, 0, 0, self.pad_length),
mode="constant",
value=self.LOG_EPS,
)
self.num_frames = self.features.size(0)
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
chunk_length = chunk_size + self.pad_length
ret_length = min(
self.num_frames - self.num_processed_frames, chunk_length
)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames # noqa
+ ret_length
]
self.num_processed_frames += chunk_size
if self.num_processed_frames >= self.num_frames:
self._done = True
return ret_features, ret_length

View File

@ -49,7 +49,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
@ -109,6 +109,17 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--streaming-model",
type=str2bool,
default=False,
help="""Whether to export a streaming model, if the models in exp-dir
are streaming model, this should be True.
""",
)
add_model_arguments(parser)
return parser
@ -130,8 +141,12 @@ def main():
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
logging.info(params)
logging.info("About to create model")

View File

@ -77,7 +77,9 @@ from beam_search import (
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
def get_parser():
@ -177,6 +179,29 @@ def get_parser():
--method is greedy_search.
""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser
@ -222,6 +247,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(f"{params}")
device = torch.device("cpu")
@ -268,9 +298,18 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=features,
x_lens=feature_lengths,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []

View File

@ -0,0 +1,678 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-size 8 \
--left-context 32 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 1000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Support only greedy_search and fast_beam_search now.
""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=0,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
rnnt_stream_list = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
if params.decoding_method == "fast_beam_search":
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
# we plus 2 here because we will cut off one frame on each size of
# encoder_embed output as they see invalid paddings. so we need extra 2
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0),
tail_length - features.size(1),
features.size(2),
),
],
dim=1,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_lens,
)
if params.decoding_method == "greedy_search":
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
elif params.decoding_method == "fast_beam_search":
config = k2.RnntDecodingConfig(
vocab_size=params.vocab_size,
decoder_history_len=params.context_size,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
model, encoder_out, processed_lens, decoding_streams
)
else:
assert False
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
if params.decoding_method == "fast_beam_search":
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
decode_stream.set_features(fbank(samples.to(device)))
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
key = "greedy_search"
if params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
# sort results so we can easily compare the difference between two
# recognition results
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.causal_convolution = True
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -34,6 +34,31 @@ def test_model():
params.context_size = 2
params.unk_id = 2
params.dynamic_chunk_training = False
params.short_chunk_size = 25
params.num_left_chunks = 4
params.causal_convolution = False
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def test_model_streaming():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.dynamic_chunk_training = True
params.short_chunk_size = 25
params.num_left_chunks = 4
params.causal_convolution = True
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
@ -44,6 +69,7 @@ def test_model():
def main():
test_model()
test_model_streaming()
if __name__ == "__main__":

View File

@ -28,6 +28,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300
# train a streaming model
./pruned_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
"""
@ -73,6 +86,42 @@ from icefall.utils import (
)
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -222,6 +271,8 @@ def get_parser():
""",
)
add_model_arguments(parser)
return parser
@ -263,7 +314,7 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- attention_dim: Hidden dim for multi-head attention model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
@ -283,7 +334,7 @@ def get_params() -> AttributeDict:
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"encoder_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
@ -305,11 +356,15 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
d_model=params.encoder_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
)
return encoder
@ -762,6 +817,11 @@ def run(rank, world_size, args):
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
logging.info(params)
logging.info("About to create model")
@ -780,7 +840,7 @@ def run(rank, world_size, args):
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
model_size=params.encoder_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)

View File

@ -37,7 +37,7 @@ def fast_beam_search_one_best(
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args:
@ -74,6 +74,202 @@ def fast_beam_search_one_best(
return hyps
def fast_beam_search_nbest_LG(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores,
log_semiring=True, # Note: we always use True
)
# See https://github.com/k2-fsa/icefall/pull/420 for why
# we always use log_semiring=True
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
The process to get the results is:
- (1) Use fast beam search to get a lattice
- (2) Select `num_paths` paths from the lattice using k2.random_paths()
- (3) Unique the selected paths
- (4) Intersect the selected paths with the lattice and compute the
shortest path from the intersection result
- (5) The path with the largest score is used as the decoding output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
use_double_scores:
True to use double precision for computation. False to use
single precision.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
@ -89,7 +285,7 @@ def fast_beam_search_nbest_oracle(
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
A lattice is first obtained using fast beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.

View File

@ -18,7 +18,7 @@
import copy
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
@ -32,7 +32,7 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.utils import make_pad_mask, subsequent_chunk_mask
class Conformer(EncoderInterface):
@ -48,6 +48,26 @@ class Conformer(EncoderInterface):
layer_dropout (float): layer-dropout rate.
cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
you want to train a streaming model, this is expected to be True.
When setting True, it will use a masking strategy to make the attention
see only limited left and right context.
short_chunk_threshold (float): a threshold to determinize the chunk size
to be used in masking training, if the randomly generated chunk size
is greater than ``max_len * short_chunk_threshold`` (max_len is the
max sequence length of current batch) then it will use
full context in training (i.e. with chunk size equals to max_len).
This will be used only when dynamic_chunk_training is True.
short_chunk_size (int): see docs above, if the randomly generated chunk
size equals to or less than ``max_len * short_chunk_threshold``, the
chunk size will be sampled uniformly from 1 to short_chunk_size.
This also will be used only when dynamic_chunk_training is True.
num_left_chunks (int): the left context (in chunks) attention can see, the
chunk size is decided by short_chunk_threshold and short_chunk_size.
A minus value means seeing full left context.
This also will be used only when dynamic_chunk_training is True.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training.
"""
def __init__(
@ -61,6 +81,11 @@ class Conformer(EncoderInterface):
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
num_left_chunks: int = -1,
causal: bool = False,
) -> None:
super(Conformer, self).__init__()
@ -76,6 +101,15 @@ class Conformer(EncoderInterface):
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_layers = num_encoder_layers
self.d_model = d_model
self.cnn_module_kernel = cnn_module_kernel
self.causal = causal
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
@ -85,8 +119,10 @@ class Conformer(EncoderInterface):
dropout,
layer_dropout,
cnn_module_kernel,
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self._init_state: List[torch.Tensor] = [torch.empty(0)]
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@ -120,15 +156,249 @@ class Conformer(EncoderInterface):
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
src_key_padding_mask = make_pad_mask(lengths)
if self.dynamic_chunk_training:
assert (
self.causal
), "Causal convolution is required for streaming conformer."
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % self.short_chunk_size + 1
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=self.num_left_chunks,
device=x.device,
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
else:
x = self.encoder(
x,
pos_emb,
mask=None,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
@torch.jit.export
def get_init_state(
self, left_context: int, device: torch.device
) -> List[torch.Tensor]:
"""Return the initial cache state of the model.
Args:
left_context: The left context size (in frames after subsampling).
Returns:
Return the initial state of the model, it is a list containing two
tensors, the first one is the cache for attentions which has a shape
of (num_encoder_layers, left_context, encoder_dim), the second one
is the cache of conv_modules which has a shape of
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
NOTE: the returned tensors are on the given device.
"""
if (
len(self._init_state) == 2
and self._init_state[0].size(1) == left_context
):
# Note: It is OK to share the init state as it is
# not going to be modified by the model
return self._init_state
init_states: List[torch.Tensor] = [
torch.zeros(
(
self.encoder_layers,
left_context,
self.d_model,
),
device=device,
),
torch.zeros(
(
self.encoder_layers,
self.cnn_module_kernel - 1,
self.d_model,
),
device=device,
),
]
self._init_state = init_states
return init_states
@torch.jit.export
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[Tensor]] = None,
processed_lens: Optional[Tensor] = None,
left_context: int = 64,
right_context: int = 4,
chunk_size: int = 16,
simulate_streaming: bool = False,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
chunk_size:
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
simulate_streaming:
If setting True, it will use a masking strategy to simulate streaming
fashion (i.e. every chunk data only see limited left context and
right context). The whole sequence is supposed to be send at a time
When using simulate_streaming.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- decode_states, the updated states including the information
of current chunk.
"""
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming:
assert states is not None
assert processed_lens is not None
assert (
len(states) == 2
and states[0].shape
== (self.encoder_layers, left_context, x.size(0), self.d_model)
and states[1].shape
== (
self.encoder_layers,
self.cnn_module_kernel - 1,
x.size(0),
self.d_model,
)
), f"""The length of states MUST be equal to 2, and the shape of
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
given {states[0].shape}. the shape of second element should be
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
processed_mask = torch.arange(left_context, device=x.device).expand(
x.size(0), left_context
)
processed_lens = processed_lens.view(x.size(0), 1)
processed_mask = (processed_lens <= processed_mask).flip(1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
embed = self.encoder_embed(x)
# cut off 1 frame on each size of embed as they see the padding
# value which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x, states = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
states=states,
left_context=left_context,
right_context=right_context,
) # (T, B, F)
if right_context > 0:
x = x[0:-right_context, ...]
lengths -= right_context
else:
assert states is None
states = [] # just to make torch.script.jit happy
# this branch simulates streaming decoding using mask as we are
# using in training time.
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=num_left_chunks,
device=x.device,
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
return x, lengths, states
class ConformerEncoderLayer(nn.Module):
@ -142,6 +412,8 @@ class ConformerEncoderLayer(nn.Module):
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -158,6 +430,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
causal: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
@ -185,7 +458,9 @@ class ConformerEncoderLayer(nn.Module):
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, causal=causal
)
self.norm_final = BasicNorm(d_model)
@ -214,7 +489,6 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
@ -248,10 +522,12 @@ class ConformerEncoderLayer(nn.Module):
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = src + self.dropout(src_att)
# convolution module
src = src + self.dropout(self.conv_module(src))
conv, _ = self.conv_module(src)
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -263,6 +539,100 @@ class ConformerEncoderLayer(nn.Module):
return src
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (left_context, src.size(1), src.size(2))
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
# We put the attention cache this level (i.e. before linear transformation)
# to save memory consumption, when decoding in streaming fashion, the
# batch size would be thousands (for 32GB machine), if we cache key & val
# separately, it needs extra several GB memory.
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
# separately) if needed.
key = torch.cat([states[0], src], dim=0)
val = key
if right_context > 0:
states[0] = key[
-(left_context + right_context) : -right_context, ... # noqa
]
else:
states[0] = key[-left_context:, ...]
# multi-headed self-attention module
src_att = self.self_attn(
src,
key,
val,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
left_context=left_context,
)[0]
src = src + self.dropout(src_att)
# convolution module
conv, conv_cache = self.conv_module(src, states[1], right_context)
states[1] = conv_cache
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
return src, states
class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers
@ -301,6 +671,8 @@ class ConformerEncoder(nn.Module):
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
Shape:
src: (S, N, E).
@ -312,7 +684,7 @@ class ConformerEncoder(nn.Module):
"""
output = src
for i, mod in enumerate(self.layers):
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
@ -323,6 +695,79 @@ class ConformerEncoder(nn.Module):
return output
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (
self.num_layers,
left_context,
src.size(1),
src.size(2),
)
assert states[1].size(0) == self.num_layers
output = src
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output, cache = mod.chunk_forward(
output,
pos_emb,
states=cache,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
left_context=left_context,
right_context=right_context,
)
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
return output, states
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
@ -347,24 +792,25 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = x.size(1) + left_context
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
@ -382,22 +828,30 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self,
x: torch.Tensor,
left_context: int = 0,
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
self.extend_pe(x, left_context)
x_size_1 = x.size(1) + left_context
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
@ -469,6 +923,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -482,6 +937,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
- Inputs:
@ -527,14 +985,18 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
left_context=left_context,
)
def rel_shift(self, x: Tensor) -> Tensor:
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
@ -542,14 +1004,19 @@ class RelPositionMultiheadAttention(nn.Module):
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
time2 = time1 + left_context
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
@ -571,6 +1038,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -588,6 +1056,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
Inputs:
@ -751,7 +1222,8 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2
@ -771,9 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
q_with_bias_v, p
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
matrix_bd = self.rel_shift(matrix_bd, left_context)
attn_output_weights = (
matrix_ac + matrix_bd
@ -808,6 +1280,39 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
# If we are using dynamic_chunk_training and setting a limited
# num_left_chunks, the attention may only see the padding values which
# will also be masked out by `key_padding_mask`, at this circumstances,
# the whole column of `attn_output_weights` will be `-inf`
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
# positions to avoid invalid loss value below.
if (
attn_mask is not None
and attn_mask.dtype == torch.bool
and key_padding_mask is not None
):
if attn_mask.size(0) != 1:
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
combined_mask = attn_mask | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(
0
) | key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
combined_mask, 0.0
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
@ -841,16 +1346,21 @@ class ConvolutionModule(nn.Module):
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
causal (bool): Whether to use causal convolution.
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = False,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.causal = causal
self.pointwise_conv1 = ScaledConv1d(
channels,
@ -878,12 +1388,17 @@ class ConvolutionModule(nn.Module):
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
self.lorder = kernel_size - 1
padding = (kernel_size - 1) // 2
if self.causal:
padding = 0
self.depthwise_conv = ScaledConv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
padding=padding,
groups=channels,
bias=bias,
)
@ -904,14 +1419,28 @@ class ConvolutionModule(nn.Module):
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
cache: The cache of depthwise_conv, only used in real streaming
decoding.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Returns:
Tensor: Output tensor (#time, batch, channels).
If cache is None return the output tensor (#time, batch, channels).
If cache is not None, return a tuple of Tensor, the first one is
the output tensor (#time, batch, channels), the second one is the
new cache for next chunk (#kernel_size - 1, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
@ -924,6 +1453,26 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.causal and self.lorder > 0:
if cache is None:
# Make depthwise_conv causal by
# manualy padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
else:
assert (
not self.training
), "Cache should be None in training time"
assert cache.size(0) == self.lorder
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
if right_context > 0:
cache = x.permute(2, 0, 1)[
-(self.lorder + right_context) : ( # noqa
-right_context
),
...,
]
else:
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
@ -931,7 +1480,11 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
# torch.jit.script requires return types be the same as annotated above
if cache is None:
cache = torch.empty(0)
return x.permute(2, 0, 1), cache
class Conv2dSubsampling(nn.Module):

View File

@ -43,21 +43,74 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) decode in streaming mode (take greedy search as an example)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method greedy_search
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -69,25 +122,32 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
@ -136,6 +196,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@ -145,6 +212,11 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
@ -160,27 +232,42 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
default=4,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
default=64,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -190,6 +277,7 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
@ -198,6 +286,48 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
@ -206,6 +336,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -229,9 +360,12 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -246,9 +380,26 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -263,6 +414,49 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -318,6 +512,17 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -327,6 +532,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -340,9 +546,12 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -360,7 +569,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -370,6 +579,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
)
@ -452,6 +662,9 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -461,10 +674,19 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -490,6 +712,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
@ -528,10 +755,24 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -553,6 +794,7 @@ def main():
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)

View File

@ -0,0 +1 @@
../pruned_transducer_stateless/decode_stream.py

View File

@ -73,6 +73,9 @@ class Decoder(nn.Module):
groups=decoder_dim,
bias=False,
)
else:
# It is to support torch script
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -49,7 +49,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -124,6 +124,16 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--streaming-model",
type=str2bool,
default=False,
help="""Whether to export a streaming model, if the models in exp-dir
are streaming model, this should be True.
""",
)
add_model_arguments(parser)
return parser
@ -147,6 +157,9 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
logging.info(params)
logging.info("About to create model")

View File

@ -52,8 +52,10 @@ class Joiner(nn.Module):
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(

View File

@ -77,7 +77,9 @@ from beam_search import (
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
def get_parser():
@ -178,6 +180,30 @@ def get_parser():
""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser
@ -222,6 +248,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(f"{params}")
device = torch.device("cpu")
@ -268,9 +299,18 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=features,
x_lens=feature_lengths,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []

View File

@ -52,7 +52,15 @@ class ActivationBalancerFunction(torch.autograd.Function):
if x.requires_grad:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
# sum_dims = [d for d in range(x.ndim) if d != channel_dim]
# The above line is not torch scriptable for torch 1.6.0
# torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa
sum_dims = []
for d in range(x.ndim):
if d != channel_dim:
sum_dims.append(d)
xgt0 = x > 0
proportion_positive = torch.mean(
xgt0.to(x.dtype), dim=sum_dims, keepdim=True
@ -214,8 +222,8 @@ class ScaledLinear(nn.Linear):
def get_bias(self):
if self.bias is None or self.bias_scale is None:
return None
return self.bias * self.bias_scale.exp()
else:
return self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(
@ -234,6 +242,9 @@ class ScaledConv1d(nn.Conv1d):
):
super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log()
self.bias_scale: Optional[nn.Parameter] # for torchscript
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
if self.bias is not None:
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
@ -262,7 +273,8 @@ class ScaledConv1d(nn.Conv1d):
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
else:
return bias * bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
@ -331,7 +343,8 @@ class ScaledConv2d(nn.Conv2d):
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
else:
return bias * bias_scale.exp()
def _conv_forward(self, input, weight):
F = torch.nn.functional
@ -412,16 +425,16 @@ class ActivationBalancer(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
return x
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
else:
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):
@ -461,7 +474,8 @@ class DoubleSwish(torch.nn.Module):
"""
if torch.jit.is_scripting():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
else:
return DoubleSwishFunction.apply(x)
class ScaledEmbedding(nn.Module):

View File

@ -0,0 +1,687 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 1000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Support only greedy_search and fast_beam_search now.
""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=0,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> List[List[int]]:
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# logging.info(f"decoder_out shape : {decoder_out.shape}")
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
hyp_tokens = []
for stream in streams:
hyp_tokens.append(stream.hyp)
return hyp_tokens
def fast_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
decoding_streams: k2.RnntDecodingStreams,
) -> List[List[int]]:
B, T, C = encoder_out.shape
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
return hyp_tokens
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
rnnt_stream_list = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
if params.decoding_method == "fast_beam_search":
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
# we plus 2 here because we will cut off one frame on each size of
# encoder_embed output as they see invalid paddings. so we need extra 2
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0),
tail_length - features.size(1),
features.size(2),
),
],
dim=1,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_lens,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
elif params.decoding_method == "fast_beam_search":
config = k2.RnntDecodingConfig(
vocab_size=params.vocab_size,
decoder_history_len=params.context_size,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
model, encoder_out, processed_lens, decoding_streams
)
else:
assert False
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
if params.decoding_method == "fast_beam_search":
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(hyp).split(),
)
)
del decode_streams[i]
key = "greedy_search"
if params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
# sort results so we can easily compare the difference between two
# recognition results
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -1,50 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless2/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless/test_model.py

View File

@ -40,6 +40,18 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--full-libri 1 \
--max-duration 550
# train a streaming model
./pruned_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
"""
@ -83,6 +95,42 @@ LRSchedulerType = Union[
]
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -263,6 +311,8 @@ def get_parser():
help="Whether to use half precision training.",
)
add_model_arguments(parser)
return parser
@ -349,6 +399,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
)
return encoder
@ -806,6 +860,11 @@ def run(rank, world_size, args):
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
logging.info(params)
logging.info("About to create model")
@ -883,6 +942,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@ -973,6 +1033,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@ -983,9 +1044,6 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
@ -993,7 +1051,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
warmup=0.0,
warmup=warmup,
)
loss.backward()
optimizer.step()

View File

@ -291,7 +291,6 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
drop_last=True,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(

Some files were not shown because too many files have changed in this diff Show More