Changes from upstream/master

This commit is contained in:
Daniel Povey 2022-03-16 21:49:15 +08:00
commit 87c92efbfe
166 changed files with 14825 additions and 3150 deletions

View File

@ -6,6 +6,8 @@ per-file-ignores =
# line too long # line too long
egs/librispeech/ASR/*/conformer.py: E501, egs/librispeech/ASR/*/conformer.py: E501,
egs/aishell/ASR/*/conformer.py: E501, egs/aishell/ASR/*/conformer.py: E501,
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605
exclude = exclude =
.git, .git,

View File

@ -0,0 +1,180 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-2022-03-12
# stateless transducer + k2 pruned rnnt-loss
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_librispeech_2022_03_12:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs
mkdir -p ~/tmp
cd ~/tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
- name: Display test files
shell: bash
run: |
sudo apt-get -qq install tree sox
tree ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
soxi ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
ls -lh ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav

View File

@ -31,9 +31,6 @@ jobs:
matrix: matrix:
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9] python-version: [3.7, 3.8, 3.9]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false fail-fast: false
@ -42,30 +39,43 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall and there is a file: requirements.txt in it
pip install -r requirements.txt
- name: Install graphviz - name: Install graphviz
shell: bash shell: bash
run: | run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model - name: Download pre-trained model
shell: bash shell: bash
run: | run: |
@ -83,7 +93,9 @@ jobs:
- name: Run CTC decoding - name: Run CTC decoding
shell: bash shell: bash
run: | run: |
export PYTHONPATH=$PWD:PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \ ./conformer_ctc/pretrained.py \
--num-classes 500 \ --num-classes 500 \
@ -98,6 +110,8 @@ jobs:
shell: bash shell: bash
run: | run: |
export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \ ./conformer_ctc/pretrained.py \
--num-classes 500 \ --num-classes 500 \

View File

@ -0,0 +1,172 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-100h
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav

View File

@ -0,0 +1,174 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-960h
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav

View File

@ -0,0 +1,173 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-trandsucer-stateless-modified-2-aishell
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_transducer_stateless_modified_2_aishell:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/aishell/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01
cd ..
tree tmp
soxi tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav
ls -lh tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav

View File

@ -0,0 +1,173 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-trandsucer-stateless-modified-aishell
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_transducer_stateless_modified_aishell:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/aishell/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
cd ..
tree tmp
soxi tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav
ls -lh tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav

View File

@ -31,9 +31,6 @@ jobs:
matrix: matrix:
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9] python-version: [3.7, 3.8, 3.9]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false fail-fast: false
@ -42,30 +39,43 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall and there is a file: requirements.txt in it
pip install -r requirements.txt
- name: Install graphviz - name: Install graphviz
shell: bash shell: bash
run: | run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model - name: Download pre-trained model
shell: bash shell: bash
run: | run: |
@ -74,35 +84,88 @@ jobs:
mkdir tmp mkdir tmp
cd tmp cd tmp
git lfs install git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
cd .. cd ..
tree tmp tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
- name: Run greedy search decoding - name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash shell: bash
run: | run: |
export PYTHONPATH=$PWD:PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method greedy_search \ --method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ --max-sym-per-frame 1 \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding - name: Run beam search decoding
shell: bash shell: bash
run: | run: |
export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav

View File

@ -31,9 +31,6 @@ jobs:
matrix: matrix:
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9] python-version: [3.7, 3.8, 3.9]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false fail-fast: false
@ -42,30 +39,43 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall and there is a file: requirements.txt in it
pip install -r requirements.txt
- name: Install graphviz - name: Install graphviz
shell: bash shell: bash
run: | run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model - name: Download pre-trained model
shell: bash shell: bash
run: | run: |
@ -84,7 +94,9 @@ jobs:
- name: Run greedy search decoding - name: Run greedy search decoding
shell: bash shell: bash
run: | run: |
export PYTHONPATH=$PWD:PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR cd egs/librispeech/ASR
./transducer/pretrained.py \ ./transducer/pretrained.py \
--method greedy_search \ --method greedy_search \
@ -98,6 +110,8 @@ jobs:
shell: bash shell: bash
run: | run: |
export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR cd egs/librispeech/ASR
./transducer/pretrained.py \ ./transducer/pretrained.py \
--method beam_search \ --method beam_search \

View File

@ -33,9 +33,6 @@ jobs:
# TODO: enable macOS for CPU testing # TODO: enable macOS for CPU testing
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.8] python-version: [3.8]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false fail-fast: false
steps: steps:
@ -43,10 +40,17 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }} - name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1 uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install libnsdfile and libsox - name: Install libnsdfile and libsox
if: startsWith(matrix.os, 'ubuntu') if: startsWith(matrix.os, 'ubuntu')
@ -57,13 +61,7 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install -U pip grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
# We are in ./icefall and there is a file: requirements.txt in it
python3 -m pip install -r requirements.txt
- name: Run yesno recipe - name: Run yesno recipe
shell: bash shell: bash

View File

@ -80,16 +80,16 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open
Using Conformer as encoder. The decoder consists of 1 embedding layer Using Conformer as encoder. The decoder consists of 1 embedding layer
and 1 convolutional layer. and 1 convolutional layer.
The best WER using beam search with beam size 4 is: The best WER using modified beam search with beam size 4 is:
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 2.68 | 6.72 | | WER | 2.56 | 6.27 |
Note: No auxiliary losses are used in the training and no LMs are used Note: No auxiliary losses are used in the training and no LMs are used
in the decoding. in the decoding.
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing) We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
### Aishell ### Aishell
@ -113,7 +113,7 @@ The best CER we currently have is:
| | test | | | test |
|-----|------| |-----|------|
| CER | 5.7 | | CER | 4.68 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)

View File

@ -33,6 +33,7 @@ release = "0.1"
# ones. # ones.
extensions = [ extensions = [
"sphinx_rtd_theme", "sphinx_rtd_theme",
"sphinx.ext.todo",
] ]
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
@ -74,3 +75,5 @@ html_context = {
"github_version": "master", "github_version": "master",
"conf_py_path": "/icefall/docs/source/", "conf_py_path": "/icefall/docs/source/",
} }
todo_include_todos = True

View File

@ -0,0 +1,4 @@
# Introduction
<https://shields.io/> is used to generate files in this directory.

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="80" height="20" role="img" aria-label="k2: &gt;= v1.9"><title>k2: &gt;= v1.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="80" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="57" height="20" fill="blueviolet"/><rect width="80" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">k2</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">k2</text><text aria-hidden="true" x="505" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="470">&gt;= v1.9</text><text x="505" y="140" transform="scale(.1)" fill="#fff" textLength="470">&gt;= v1.9</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="58" height="20" role="img" aria-label="k2: v1.9"><title>k2: v1.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="58" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="35" height="20" fill="blueviolet"/><rect width="58" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">k2</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">k2</text><text aria-hidden="true" x="395" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="250">v1.9</text><text x="395" y="140" transform="scale(.1)" fill="#fff" textLength="250">v1.9</text></g></svg>

Before

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="170" height="20" role="img" aria-label="python: 3.6 | 3.7 | 3.8 | 3.9"><title>python: 3.6 | 3.7 | 3.8 | 3.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="170" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="49" height="20" fill="#555"/><rect x="49" width="121" height="20" fill="#007ec6"/><rect width="170" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="255" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">python</text><text x="255" y="140" transform="scale(.1)" fill="#fff" textLength="390">python</text><text aria-hidden="true" x="1085" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text><text x="1085" y="140" transform="scale(.1)" fill="#fff" textLength="1110">3.6 | 3.7 | 3.8 | 3.9</text></g></svg>

Before

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="98" height="20" role="img" aria-label="python: &gt;= 3.6"><title>python: &gt;= 3.6</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="98" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="49" height="20" fill="#555"/><rect x="49" width="49" height="20" fill="#007ec6"/><rect width="98" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="255" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">python</text><text x="255" y="140" transform="scale(.1)" fill="#fff" textLength="390">python</text><text aria-hidden="true" x="725" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="390">&gt;= 3.6</text><text x="725" y="140" transform="scale(.1)" fill="#fff" textLength="390">&gt;= 3.6</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="286" height="20" role="img" aria-label="torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0"><title>torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="286" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="39" height="20" fill="#555"/><rect x="39" width="247" height="20" fill="#97ca00"/><rect width="286" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="205" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="290">torch</text><text x="205" y="140" transform="scale(.1)" fill="#fff" textLength="290">torch</text><text aria-hidden="true" x="1615" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text><text x="1615" y="140" transform="scale(.1)" fill="#fff" textLength="2370">1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0</text></g></svg>

Before

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="100" height="20" role="img" aria-label="torch: &gt;= 1.6.0"><title>torch: &gt;= 1.6.0</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="100" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="39" height="20" fill="#555"/><rect x="39" width="61" height="20" fill="#97ca00"/><rect width="100" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="205" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="290">torch</text><text x="205" y="140" transform="scale(.1)" fill="#fff" textLength="290">torch</text><text aria-hidden="true" x="685" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="510">&gt;= 1.6.0</text><text x="685" y="140" transform="scale(.1)" fill="#fff" textLength="510">&gt;= 1.6.0</text></g></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@ -15,13 +15,13 @@ Installation
.. |device| image:: ./images/device-CPU_CUDA-orange.svg .. |device| image:: ./images/device-CPU_CUDA-orange.svg
:alt: Supported devices :alt: Supported devices
.. |python_versions| image:: ./images/python-3.6_3.7_3.8_3.9-blue.svg .. |python_versions| image:: ./images/python-gt-v3.6-blue.svg
:alt: Supported python versions :alt: Supported python versions
.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg .. |torch_versions| image:: ./images/torch-gt-v1.6.0-green.svg
:alt: Supported PyTorch versions :alt: Supported PyTorch versions
.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg .. |k2_versions| image:: ./images/k2-gt-v1.9-blueviolet.svg
:alt: Supported k2 versions :alt: Supported k2 versions
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and ``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and

View File

@ -1,10 +0,0 @@
Aishell
=======
We provide the following models for the Aishell dataset:
.. toctree::
:maxdepth: 2
aishell/conformer_ctc
aishell/tdnn_lstm_ctc

View File

@ -1,4 +1,4 @@
Confromer CTC Conformer CTC
============= =============
This tutorial shows you how to run a conformer ctc model This tutorial shows you how to run a conformer ctc model

Binary file not shown.

After

Width:  |  Height:  |  Size: 441 KiB

View File

@ -0,0 +1,22 @@
aishell
=======
Aishell is an open-source Chinese Mandarin speech corpus published by Beijing
Shell Shell Technology Co.,Ltd.
400 people from different accent areas in China are invited to participate in
the recording, which is conducted in a quiet indoor environment using high
fidelity microphone and downsampled to 16kHz. The manual transcription accuracy
is above 95%, through professional speech annotation and strict quality
inspection. The data is free for academic use. We hope to provide moderate
amount of data for new researchers in the field of speech recognition.
It can be downloaded from `<https://www.openslr.org/33/>`_
.. toctree::
:maxdepth: 1
tdnn_lstm_ctc
conformer_ctc
stateless_transducer

View File

@ -0,0 +1,714 @@
Stateless Transducer
====================
This tutorial shows you how to do transducer training in ``icefall``.
.. HINT::
Instead of using RNN-T or RNN transducer, we only use transducer
here. As you will see, there are no RNNs in the model.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
.. HINT::
We recommend you to use a GPU or several GPUs to run this recipe.
In this tutorial, you will learn:
- (1) What does the transducer model look like
- (2) How to prepare data for training and decoding
- (3) How to start the training, either with a single GPU or with multiple GPUs
- (4) How to do decoding after training, with greedy search, beam search and, **modified beam search**
- (5) How to use a pre-trained model provided by us to transcribe sound files
The Model
---------
The transducer model consists of 3 parts:
- **Encoder**: It is a conformer encoder with the following parameters
- Number of heads: 8
- Attention dim: 512
- Number of layers: 12
- Feedforward dim: 2048
- **Decoder**: We use a stateless model consisting of:
- An embedding layer with embedding dim 512
- A Conv1d layer with a default kernel size 2 (i.e. it sees 2
symbols of left-context by default)
- **Joiner**: It consists of a ``nn.tanh()`` and a ``nn.Linear()``.
.. Caution::
The decoder is stateless and very simple. It is borrowed from
`<https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419>`_
(Rnn-Transducer with Stateless Prediction Network)
We make one modification to it: Place a Conv1d layer right after
the embedding layer.
When using Chinese characters as modelling unit, whose vocabulary size
is 4336 in this specific dataset,
the number of parameters of the model is ``87939824``, i.e., about ``88 M``.
The Loss
--------
We are using `<https://github.com/csukuangfj/optimized_transducer>`_
to compute the transducer loss, which removes extra paddings
in loss computation to save memory.
.. Hint::
``optimized_transducer`` implements the technqiues proposed
in `Improving RNN Transducer Modeling for End-to-End Speech Recognition <https://arxiv.org/abs/1909.12415>`_ to save memory.
Furthermore, it supports ``modified transducer``, limiting the maximum
number of symbols that can be emitted per frame to 1, which simplifies
the decoding process significantly. Also, the experiment results
show that it does not degrade the performance.
See `<https://github.com/csukuangfj/optimized_transducer#modified-transducer>`_
for what exactly modified transducer is.
`<https://github.com/csukuangfj/transducer-loss-benchmarking>`_ shows that
in the unpruned case ``optimized_transducer`` has the advantage about minimizing
memory usage.
.. todo::
Add tutorial about ``pruned_transducer_stateless`` that uses k2
pruned transducer loss.
.. hint::
You can use::
pip install optimized_transducer
to install ``optimized_transducer``. Refer to
`<https://github.com/csukuangfj/optimized_transducer>`_ for other
alternatives.
Data Preparation
----------------
To prepare the data for training, please use the following commands:
.. code-block:: bash
cd egs/aishell/ASR
./prepare.sh --stop-stage 4
./prepare.sh --stage 6 --stop-stage 6
.. note::
You can use ``./prepare.sh``, though it will generate FSTs that
are not used in transducer training.
When you finish running the script, you will get the following two folders:
- ``data/fbank``: It saves the pre-computed features
- ``data/lang_char``: It contains tokens that will be used in the training
Training
--------
.. code-block:: bash
cd egs/aishell/ASR
./transducer_stateless_modified/train.py --help
shows you the training options that can be passed from the commandline.
The following options are used quite often:
- ``--exp-dir``
The experiment folder to save logs and model checkpoints,
defaults to ``./transducer_stateless_modified/exp``.
- ``--num-epochs``
It is the number of epochs to train. For instance,
``./transducer_stateless_modified/train.py --num-epochs 30`` trains for 30
epochs and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
in the folder set by ``--exp-dir``.
- ``--start-epoch``
It's used to resume training.
``./transducer_stateless_modified/train.py --start-epoch 10`` loads the
checkpoint from ``exp_dir/epoch-9.pt`` and starts
training from epoch 10, based on the state from epoch 9.
- ``--world-size``
It is used for single-machine multi-GPU DDP training.
- (a) If it is 1, then no DDP training is used.
- (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
The following shows some use cases with it.
**Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
GPU 2 for training. You can do the following:
.. code-block:: bash
$ cd egs/aishell/ASR
$ export CUDA_VISIBLE_DEVICES="0,2"
$ ./transducer_stateless_modified/train.py --world-size 2
**Use case 2**: You have 4 GPUs and you want to use all of them
for training. You can do the following:
.. code-block:: bash
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/train.py --world-size 4
**Use case 3**: You have 4 GPUs but you only want to use GPU 3
for training. You can do the following:
.. code-block:: bash
$ cd egs/aishell/ASR
$ export CUDA_VISIBLE_DEVICES="3"
$ ./transducer_stateless_modified/train.py --world-size 1
.. CAUTION::
Only single-machine multi-GPU DDP training is implemented at present.
There is an on-going PR `<https://github.com/k2-fsa/icefall/pull/63>`_
that adds support for multi-machine multi-GPU DDP training.
- ``--max-duration``
It specifies the number of seconds over all utterances in a
batch **before padding**.
If you encounter CUDA OOM, please reduce it. For instance, if
your are using V100 NVIDIA GPU with 32 GB RAM, we recommend you
to set it to ``300`` when the vocabulary size is 500.
.. HINT::
Due to padding, the number of seconds of all utterances in a
batch will usually be larger than ``--max-duration``.
A larger value for ``--max-duration`` may cause OOM during training,
while a smaller value may increase the training time. You have to
tune it.
- ``--lr-factor``
It controls the learning rate. If you use a single GPU for training, you
may want to use a small value for it. If you use multiple GPUs for training,
you may increase it.
- ``--context-size``
It specifies the kernel size in the decoder. The default value 2 means it
functions as a tri-gram LM.
- ``--modified-transducer-prob``
It specifies the probability to use modified transducer loss.
If it is 0, then no modified transducer is used; if it is 1,
then it uses modified transducer loss for all batches. If it is
``p``, it applies modified transducer with probability ``p``.
There are some training options, e.g.,
number of warmup steps,
that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in
`transducer_stateless_modified/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/transducer_stateless_modified/train.py#L162>`_
If you need to change them, please modify ``./transducer_stateless_modified/train.py`` directly.
.. CAUTION::
The training set is perturbed by speed with two factors: 0.9 and 1.1.
Each epoch actually processes ``3x150 == 450`` hours of data.
Training logs
~~~~~~~~~~~~~
Training logs and checkpoints are saved in the folder set by ``--exp-dir``
(defaults to ``transducer_stateless_modified/exp``). You will find the following files in that directory:
- ``epoch-0.pt``, ``epoch-1.pt``, ...
These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
.. code-block:: bash
$ ./transducer_stateless_modified/train.py --start-epoch 11
- ``tensorboard/``
This folder contains TensorBoard logs. Training loss, validation loss, learning
rate, etc, are recorded in these logs. You can visualize them by:
.. code-block:: bash
$ cd transducer_stateless_modified/exp/tensorboard
$ tensorboard dev upload --logdir . --name "Aishell transducer training with icefall" --description "Training modified transducer, see https://github.com/k2-fsa/icefall/pull/219"
It will print something like below:
.. code-block::
TensorFlow installation not found - running with reduced feature set.
Upload started and will continue reading any new data as it's added to the logdir.
To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/laGZ6HrcQxOigbFD5E0Y3Q/
[2022-03-03T14:29:45] Started scanning logdir.
[2022-03-03T14:29:48] Total uploaded: 8477 scalars, 0 tensors, 0 binary objects
Listening for new data in logdir...
Note there is a `URL <https://tensorboard.dev/experiment/laGZ6HrcQxOigbFD5E0Y3Q/>`_ in the
above output, click it and you will see the following screenshot:
.. figure:: images/aishell-transducer_stateless_modified-tensorboard-log.png
:width: 600
:alt: TensorBoard screenshot
:align: center
:target: https://tensorboard.dev/experiment/laGZ6HrcQxOigbFD5E0Y3Q
TensorBoard screenshot.
- ``log/log-train-xxxx``
It is the detailed training log in text format, same as the one
you saw printed to the console during training.
Usage examples
~~~~~~~~~~~~~~
The following shows typical use cases:
**Case 1**
^^^^^^^^^^
.. code-block:: bash
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/train.py --max-duration 250
It uses ``--max-duration`` of 250 to avoid OOM.
**Case 2**
^^^^^^^^^^
.. code-block:: bash
$ cd egs/aishell/ASR
$ export CUDA_VISIBLE_DEVICES="0,3"
$ ./transducer_stateless_modified/train.py --world-size 2
It uses GPU 0 and GPU 3 for DDP training.
**Case 3**
^^^^^^^^^^
.. code-block:: bash
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/train.py --num-epochs 10 --start-epoch 3
It loads checkpoint ``./transducer_stateless_modified/exp/epoch-2.pt`` and starts
training from epoch 3. Also, it trains for 10 epochs.
Decoding
--------
The decoding part uses checkpoints saved by the training part, so you have
to run the training part first.
.. code-block:: bash
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/decode.py --help
shows the options for decoding.
The commonly used options are:
- ``--method``
This specifies the decoding method. Currently, it supports:
- **greedy_search**. You can provide the commandline option ``--max-sym-per-frame``
to limit the maximum number of symbols that can be emitted per frame.
- **beam_search**. You can provide the commandline option ``--beam-size``.
- **modified_beam_search**. You can also provide the commandline option ``--beam-size``.
To use this method, we assume that you have trained your model with modified transducer,
i.e., used the option ``--modified-transducer-prob`` in the training.
The following command uses greedy search for decoding
.. code-block::
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/decode.py \
--epoch 64 \
--avg 33 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method greedy_search \
--max-sym-per-frame 1
The following command uses beam search for decoding
.. code-block::
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/decode.py \
--epoch 64 \
--avg 33 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
The following command uses ``modified`` beam search for decoding
.. code-block::
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/decode.py \
--epoch 64 \
--avg 33 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
- ``--max-duration``
It has the same meaning as the one used in training. A larger
value may cause OOM.
- ``--epoch``
It specifies the checkpoint from which epoch that should be used for decoding.
- ``--avg``
It specifies the number of models to average. For instance, if it is 3 and if
``--epoch=10``, then it averages the checkpoints ``epoch-8.pt``, ``epoch-9.pt``,
and ``epoch-10.pt`` and the averaged checkpoint is used for decoding.
After decoding, you can find the decoding logs and results in `exp_dir/log/<decoding_method>`, e.g.,
``exp_dir/log/greedy_search``.
Pre-trained Model
-----------------
We have uploaded a pre-trained model to
`<https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01>`_
We describe how to use the pre-trained model to transcribe a sound file or
multiple sound files in the following.
Install kaldifeat
~~~~~~~~~~~~~~~~~
`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used to
extract features for a single sound file or multiple sound files
at the same time.
Please refer to `<https://github.com/csukuangfj/kaldifeat>`_ for installation.
Download the pre-trained model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The following commands describe how to download the pre-trained model:
.. code-block::
$ cd egs/aishell/ASR
$ mkdir tmp
$ cd tmp
$ git lfs install
$ git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
.. CAUTION::
You have to use ``git lfs`` to download the pre-trained model.
After downloading, you will have the following files:
.. code-block:: bash
$ cd egs/aishell/ASR
$ tree tmp/icefall-aishell-transducer-stateless-modified-2022-03-01
.. code-block:: bash
tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/
|-- README.md
|-- data
| `-- lang_char
| |-- L.pt
| |-- lexicon.txt
| |-- tokens.txt
| `-- words.txt
|-- exp
| `-- pretrained.pt
|-- log
| |-- errs-test-beam_4-epoch-64-avg-33-beam-4.txt
| |-- errs-test-greedy_search-epoch-64-avg-33-context-2-max-sym-per-frame-1.txt
| |-- log-decode-epoch-64-avg-33-beam-4-2022-03-02-12-05-03
| |-- log-decode-epoch-64-avg-33-context-2-max-sym-per-frame-1-2022-02-28-18-13-07
| |-- recogs-test-beam_4-epoch-64-avg-33-beam-4.txt
| `-- recogs-test-greedy_search-epoch-64-avg-33-context-2-max-sym-per-frame-1.txt
`-- test_wavs
|-- BAC009S0764W0121.wav
|-- BAC009S0764W0122.wav
|-- BAC009S0764W0123.wav
`-- transcript.txt
5 directories, 16 files
**File descriptions**:
- ``data/lang_char``
It contains language related files. You can find the vocabulary size in ``tokens.txt``.
- ``exp/pretrained.pt``
It contains pre-trained model parameters, obtained by averaging
checkpoints from ``epoch-32.pt`` to ``epoch-64.pt``.
Note: We have removed optimizer ``state_dict`` to reduce file size.
- ``log``
It contains decoding logs and decoded results.
- ``test_wavs``
It contains some test sound files from Aishell ``test`` dataset.
The information of the test sound files is listed below:
.. code-block:: bash
$ soxi tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav
Input File : 'tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.20 = 67263 samples ~ 315.295 CDDA sectors
File Size : 135k
Bit Rate : 256k
Sample Encoding: 16-bit Signed Integer PCM
Input File : 'tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.12 = 65840 samples ~ 308.625 CDDA sectors
File Size : 132k
Bit Rate : 256k
Sample Encoding: 16-bit Signed Integer PCM
Input File : 'tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.00 = 64000 samples ~ 300 CDDA sectors
File Size : 128k
Bit Rate : 256k
Sample Encoding: 16-bit Signed Integer PCM
Total Duration of 3 files: 00:00:12.32
Usage
~~~~~
.. code-block::
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/pretrained.py --help
displays the help information.
It supports three decoding methods:
- greedy search
- beam search
- modified beam search
.. note::
In modified beam search, it limits the maximum number of symbols that can be
emitted per frame to 1. To use this method, you have to ensure that your model
has been trained with the option ``--modified-transducer-prob``. Otherwise,
it may give you poor results.
Greedy search
^^^^^^^^^^^^^
The command to run greedy search is given below:
.. code-block:: bash
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/pretrained.py \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
--method greedy_search \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
The output is as follows:
.. code-block::
2022-03-03 15:35:26,531 INFO [pretrained.py:239] device: cuda:0
2022-03-03 15:35:26,994 INFO [lexicon.py:176] Loading pre-compiled tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char/Linv.pt
2022-03-03 15:35:27,027 INFO [pretrained.py:246] {'feature_dim': 80, 'encoder_out_dim': 512, 'subsampling_factor': 4, 'attention_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'vgg_frontend': False, 'env_info': {'k2-version': '1.13', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'f4fefe4882bc0ae59af951da3f47335d5495ef71', 'k2-git-date': 'Thu Feb 10 15:16:02 2022', 'lhotse-version': '1.0.0.dev+missing.version.file', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '50d2281-clean', 'icefall-git-date': 'Wed Mar 2 16:02:38 2022', 'icefall-path': '/ceph-fj/fangjun/open-source-2/icefall-aishell', 'k2-path': '/ceph-fj/fangjun/open-source-2/k2-multi-datasets/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-aishell/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-2-0815224919-75d558775b-mmnv8', 'IP address': '10.177.72.138'}, 'sample_rate': 16000, 'checkpoint': './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt', 'lang_dir': PosixPath('tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char'), 'method': 'greedy_search', 'sound_files': ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'], 'beam_size': 4, 'context_size': 2, 'max_sym_per_frame': 3, 'blank_id': 0, 'vocab_size': 4336}
2022-03-03 15:35:27,027 INFO [pretrained.py:248] About to create model
2022-03-03 15:35:36,878 INFO [pretrained.py:257] Constructing Fbank computer
2022-03-03 15:35:36,880 INFO [pretrained.py:267] Reading sound files: ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav']
2022-03-03 15:35:36,891 INFO [pretrained.py:273] Decoding started
/ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/conformer.py:113: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
lengths = ((x_lens - 1) // 2 - 1) // 2
2022-03-03 15:35:37,163 INFO [pretrained.py:320]
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav:
甚 至 出 现 交 易 几 乎 停 滞 的 情 况
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav:
一 二 线 城 市 虽 然 也 处 于 调 整 中
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav:
但 因 为 聚 集 了 过 多 公 共 资 源
2022-03-03 15:35:37,163 INFO [pretrained.py:322] Decoding Done
Beam search
^^^^^^^^^^^
The command to run beam search is given below:
.. code-block:: bash
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/pretrained.py \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
--method beam_search \
--beam-size 4 \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
The output is as follows:
.. code-block::
2022-03-03 15:39:09,285 INFO [pretrained.py:239] device: cuda:0
2022-03-03 15:39:09,708 INFO [lexicon.py:176] Loading pre-compiled tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char/Linv.pt
2022-03-03 15:39:09,759 INFO [pretrained.py:246] {'feature_dim': 80, 'encoder_out_dim': 512, 'subsampling_factor': 4, 'attention_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'vgg_frontend': False, 'env_info': {'k2-version': '1.13', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'f4fefe4882bc0ae59af951da3f47335d5495ef71', 'k2-git-date': 'Thu Feb 10 15:16:02 2022', 'lhotse-version': '1.0.0.dev+missing.version.file', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '50d2281-clean', 'icefall-git-date': 'Wed Mar 2 16:02:38 2022', 'icefall-path': '/ceph-fj/fangjun/open-source-2/icefall-aishell', 'k2-path': '/ceph-fj/fangjun/open-source-2/k2-multi-datasets/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-aishell/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-2-0815224919-75d558775b-mmnv8', 'IP address': '10.177.72.138'}, 'sample_rate': 16000, 'checkpoint': './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt', 'lang_dir': PosixPath('tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char'), 'method': 'beam_search', 'sound_files': ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'], 'beam_size': 4, 'context_size': 2, 'max_sym_per_frame': 3, 'blank_id': 0, 'vocab_size': 4336}
2022-03-03 15:39:09,760 INFO [pretrained.py:248] About to create model
2022-03-03 15:39:18,919 INFO [pretrained.py:257] Constructing Fbank computer
2022-03-03 15:39:18,922 INFO [pretrained.py:267] Reading sound files: ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav']
2022-03-03 15:39:18,929 INFO [pretrained.py:273] Decoding started
/ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/conformer.py:113: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
lengths = ((x_lens - 1) // 2 - 1) // 2
2022-03-03 15:39:21,046 INFO [pretrained.py:320]
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav:
甚 至 出 现 交 易 几 乎 停 滞 的 情 况
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav:
一 二 线 城 市 虽 然 也 处 于 调 整 中
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav:
但 因 为 聚 集 了 过 多 公 共 资 源
2022-03-03 15:39:21,047 INFO [pretrained.py:322] Decoding Done
Modified Beam search
^^^^^^^^^^^^^^^^^^^^
The command to run modified beam search is given below:
.. code-block:: bash
$ cd egs/aishell/ASR
$ ./transducer_stateless_modified/pretrained.py \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
--method modified_beam_search \
--beam-size 4 \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
The output is as follows:
.. code-block::
2022-03-03 15:41:23,319 INFO [pretrained.py:239] device: cuda:0
2022-03-03 15:41:23,798 INFO [lexicon.py:176] Loading pre-compiled tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char/Linv.pt
2022-03-03 15:41:23,831 INFO [pretrained.py:246] {'feature_dim': 80, 'encoder_out_dim': 512, 'subsampling_factor': 4, 'attention_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'vgg_frontend': False, 'env_info': {'k2-version': '1.13', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'f4fefe4882bc0ae59af951da3f47335d5495ef71', 'k2-git-date': 'Thu Feb 10 15:16:02 2022', 'lhotse-version': '1.0.0.dev+missing.version.file', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '50d2281-clean', 'icefall-git-date': 'Wed Mar 2 16:02:38 2022', 'icefall-path': '/ceph-fj/fangjun/open-source-2/icefall-aishell', 'k2-path': '/ceph-fj/fangjun/open-source-2/k2-multi-datasets/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-aishell/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-2-0815224919-75d558775b-mmnv8', 'IP address': '10.177.72.138'}, 'sample_rate': 16000, 'checkpoint': './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt', 'lang_dir': PosixPath('tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char'), 'method': 'modified_beam_search', 'sound_files': ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'], 'beam_size': 4, 'context_size': 2, 'max_sym_per_frame': 3, 'blank_id': 0, 'vocab_size': 4336}
2022-03-03 15:41:23,831 INFO [pretrained.py:248] About to create model
2022-03-03 15:41:32,214 INFO [pretrained.py:257] Constructing Fbank computer
2022-03-03 15:41:32,215 INFO [pretrained.py:267] Reading sound files: ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav']
2022-03-03 15:41:32,220 INFO [pretrained.py:273] Decoding started
/ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/conformer.py:113: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
lengths = ((x_lens - 1) // 2 - 1) // 2
/ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/beam_search.py:402: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
topk_hyp_indexes = topk_indexes // logits.size(-1)
2022-03-03 15:41:32,583 INFO [pretrained.py:320]
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav:
甚 至 出 现 交 易 几 乎 停 滞 的 情 况
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav:
一 二 线 城 市 虽 然 也 处 于 调 整 中
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav:
但 因 为 聚 集 了 过 多 公 共 资 源
2022-03-03 15:41:32,583 INFO [pretrained.py:322] Decoding Done
Colab notebook
--------------
We provide a colab notebook for this recipe showing how to use a pre-trained model to
transcribe sound files.
|aishell asr stateless modified transducer colab notebook|
.. |aishell asr stateless modified transducer colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/12jpTxJB44vzwtcmJl2DTdznW0OawPb9H?usp=sharing

View File

@ -10,12 +10,10 @@ We may add recipes for other tasks as well in the future.
.. Other recipes are listed in a alphabetical order. .. Other recipes are listed in a alphabetical order.
.. toctree:: .. toctree::
:maxdepth: 3 :maxdepth: 2
:caption: Table of Contents
yesno aishell/index
librispeech/index
librispeech timit/index
yesno/index
aishell
timit

View File

@ -1,10 +0,0 @@
LibriSpeech
===========
We provide the following models for the LibriSpeech dataset:
.. toctree::
:maxdepth: 2
librispeech/tdnn_lstm_ctc
librispeech/conformer_ctc

View File

@ -0,0 +1,8 @@
LibriSpeech
===========
.. toctree::
:maxdepth: 1
tdnn_lstm_ctc
conformer_ctc

View File

@ -1,10 +0,0 @@
TIMIT
===========
We provide the following models for the TIMIT dataset:
.. toctree::
:maxdepth: 2
timit/tdnn_lstm_ctc
timit/tdnn_ligru_ctc

View File

@ -0,0 +1,9 @@
TIMIT
=====
.. toctree::
:maxdepth: 1
tdnn_ligru_ctc
tdnn_lstm_ctc

View File

@ -1,5 +1,5 @@
TDNN-LiGRU-CTC TDNN-LiGRU-CTC
============= ==============
This tutorial shows you how to run a TDNN-LiGRU-CTC model with the `TIMIT <https://data.deepai.org/timit.zip>`_ dataset. This tutorial shows you how to run a TDNN-LiGRU-CTC model with the `TIMIT <https://data.deepai.org/timit.zip>`_ dataset.

View File

Before

Width:  |  Height:  |  Size: 121 KiB

After

Width:  |  Height:  |  Size: 121 KiB

View File

@ -0,0 +1,7 @@
YesNo
=====
.. toctree::
:maxdepth: 1
tdnn

View File

@ -1,5 +1,5 @@
yesno TDNN-CTC
===== ========
This page shows you how to run the `yesno <https://www.openslr.org/1>`_ recipe. It contains: This page shows you how to run the `yesno <https://www.openslr.org/1>`_ recipe. It contains:
@ -145,7 +145,7 @@ In ``tdnn/exp``, you will find the following files:
Note there is a URL in the above output, click it and you will see Note there is a URL in the above output, click it and you will see
the following screenshot: the following screenshot:
.. figure:: images/yesno-tdnn-tensorboard-log.png .. figure:: images/tdnn-tensorboard-log.png
:width: 600 :width: 600
:alt: TensorBoard screenshot :alt: TensorBoard screenshot
:align: center :align: center

View File

@ -1,3 +1,20 @@
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell.html> # Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
for how to run models in this recipe. for how to run models in this recipe.
# Transducers
There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
| `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` |
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.

View File

@ -1,12 +1,198 @@
## Results ## Results
### Aishell training result(Transducer-stateless)
#### 2022-03-01
[./transducer_stateless_modified-2](./transducer_stateless_modified-2)
Stateless transducer + modified transducer + using [aidatatang_200zh](http://www.openslr.org/62/) as extra training data.
| | test |comment |
|------------------------|------|----------------------------------------------------------------|
| greedy search | 4.94 |--epoch 89, --avg 38, --max-duration 100, --max-sym-per-frame 1 |
| modified beam search | 4.68 |--epoch 89, --avg 38, --max-duration 100 --beam-size 4 |
The training commands are:
```bash
cd egs/aishell/ASR
./prepare.sh --stop-stage 6
./prepare_aidatatang_200zh.sh
export CUDA_VISIBLE_DEVICES="0,1,2"
./transducer_stateless_modified-2/train.py \
--world-size 3 \
--num-epochs 90 \
--start-epoch 0 \
--exp-dir transducer_stateless_modified-2/exp-2 \
--max-duration 250 \
--lr-factor 2.0 \
--context-size 2 \
--modified-transducer-prob 0.25 \
--datatang-prob 0.2
```
The tensorboard log is available at
<https://tensorboard.dev/experiment/oG72ZlWaSGua6fXkcGRRjA/>
The commands for decoding are
```bash
# greedy search
for epoch in 89; do
for avg in 38; do
./transducer_stateless_modified-2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_modified-2/exp-2 \
--max-duration 100 \
--context-size 2 \
--decoding-method greedy_search \
--max-sym-per-frame 1
done
done
# modified beam search
for epoch in 89; do
for avg in 38; do
./transducer_stateless_modified-2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_modified-2/exp-2 \
--max-duration 100 \
--context-size 2 \
--decoding-method modified_beam_search \
--beam-size 4
done
done
```
You can find a pre-trained model, decoding logs, and decoding results at
<https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01>
#### 2022-03-01
[./transducer_stateless_modified](./transducer_stateless_modified)
Stateless transducer + modified transducer.
| | test |comment |
|------------------------|------|----------------------------------------------------------------|
| greedy search | 5.22 |--epoch 64, --avg 33, --max-duration 100, --max-sym-per-frame 1 |
| modified beam search | 5.02 |--epoch 64, --avg 33, --max-duration 100 --beam-size 4 |
The training commands are:
```bash
cd egs/aishell/ASR
./prepare.sh --stop-stage 6
export CUDA_VISIBLE_DEVICES="0,1,2"
./transducer_stateless_modified/train.py \
--world-size 3 \
--num-epochs 90 \
--start-epoch 0 \
--exp-dir transducer_stateless_modified/exp-4 \
--max-duration 250 \
--lr-factor 2.0 \
--context-size 2 \
--modified-transducer-prob 0.25
```
The tensorboard log is available at
<https://tensorboard.dev/experiment/C27M8YxRQCa1t2XglTqlWg/>
The commands for decoding are
```bash
# greedy search
for epoch in 64; do
for avg in 33; do
./transducer_stateless_modified/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_modified/exp-4 \
--max-duration 100 \
--context-size 2 \
--decoding-method greedy_search \
--max-sym-per-frame 1
done
done
# modified beam search
for epoch in 64; do
for avg in 33; do
./transducer_stateless_modified/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_modified/exp-4 \
--max-duration 100 \
--context-size 2 \
--decoding-method modified_beam_search \
--beam-size 4
done
done
```
You can find a pre-trained model, decoding logs, and decoding results at
<https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01>
#### 2022-2-19
(Duo Ma): The tensorboard log for training is available at https://tensorboard.dev/experiment/25PmX3MxSVGTdvIdhOwllw/#scalars
You can find a pretrained model by visiting https://huggingface.co/shuanguanma/icefall_aishell_transducer_stateless_context_size2_epoch60_2022_2_19
| | test |comment |
|---------------------------|------|-----------------------------------------|
| greedy search | 5.4 |--epoch 59, --avg 10, --max-duration 100 |
| beam search | 5.05|--epoch 59, --avg 10, --max-duration 100 |
You can use the following commands to reproduce our results:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python3 ./transducer_stateless/train.py \
--world-size 4 \
--num-epochs 60 \
--start-epoch 0 \
--exp-dir exp/transducer_stateless_context_size2 \
--max-duration 100 \
--lr-factor 2.5 \
--context-size 2
lang_dir=data/lang_char
dir=exp/transducer_stateless_context_size2
python3 ./transducer_stateless/decode.py \
--epoch 59 \
--avg 10 \
--exp-dir $dir \
--lang-dir $lang_dir \
--decoding-method greedy_search \
--context-size 2 \
--max-sym-per-frame 3
lang_dir=data/lang_char
dir=exp/transducer_stateless_context_size2
python3 ./transducer_stateless/decode.py \
--epoch 59 \
--avg 10 \
--exp-dir $dir \
--lang-dir $lang_dir \
--decoding-method beam_search \
--context-size 2 \
--max-sym-per-frame 3
```
### Aishell training results (Transducer-stateless) ### Aishell training results (Transducer-stateless)
#### 2021-12-29 #### 2022-02-18
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/> (Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/k3QL6QMhRbCwCKYKM9po9w/>
And pretrained model is available at <https://huggingface.co/pfluo/icefall-aishell-transducer-stateless-char-2021-12-29>
||test| ||test|
|--|--| |--|--|
|CER| 5.7% | |CER| 5.05% |
You can use the following commands to reproduce our results: You can use the following commands to reproduce our results:
@ -16,17 +202,17 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8"
--bucketing-sampler True \ --bucketing-sampler True \
--world-size 8 \ --world-size 8 \
--lang-dir data/lang_char \ --lang-dir data/lang_char \
--num-epochs 40 \ --num-epochs 60 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir transducer_stateless/exp_char \ --exp-dir transducer_stateless/exp_rnnt_k2 \
--max-duration 160 \ --max-duration 80 \
--lr-factor 3 --lr-factor 3
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--epoch 39 \ --epoch 59 \
--avg 10 \ --avg 10 \
--lang-dir data/lang_char \ --lang-dir data/lang_char \
--exp-dir transducer_stateless/exp_char \ --exp-dir transducer_stateless/exp_rnnt_k2 \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4

View File

@ -121,6 +121,13 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser return parser
@ -555,7 +562,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
fix_random_seed(42) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -618,6 +625,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate cur_lr = optimizer._rate

View File

@ -124,6 +124,13 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser return parser
@ -546,7 +553,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
fix_random_seed(42) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -613,6 +620,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate cur_lr = optimizer._rate

View File

@ -1,156 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_hlg.py

View File

@ -1,110 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_musan(num_mel_bins: int = 80):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"music",
"speech",
"noise",
)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir
)
assert manifests is not None
musan_cuts_path = output_dir / "cuts_musan.json.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(
part["recordings"] for part in manifests.values()
)
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/feats_musan",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomHdf5Writer,
)
)
musan_cuts.to_json(musan_cuts_path)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_musan(num_mel_bins=args.num_mel_bins)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -1,107 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
"""
Convert a transcript file containing words to a corpus file containing tokens
for LM training with the help of a lexicon.
If the lexicon contains phones, the resulting LM will be a phone LM; If the
lexicon contains word pieces, the resulting LM will be a word piece LM.
If a word has multiple pronunciations, the one that appears first in the lexicon
is kept; others are removed.
If the input transcript is:
hello zoo world hello
world zoo
foo zoo world hellO
and if the lexicon is
<UNK> SPN
hello h e l l o 2
hello h e l l o
world w o r l d
zoo z o o
Then the output is
h e l l o 2 z o o w o r l d h e l l o 2
w o r l d z o o
SPN z o o w o r l d SPN
"""
import argparse
from pathlib import Path
from typing import Dict, List
from generate_unique_lexicon import filter_multiple_pronunications
from icefall.lexicon import read_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transcript",
type=str,
help="The input transcript file."
"We assume that the transcript file consists of "
"lines. Each line consists of space separated words.",
)
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
parser.add_argument(
"--oov", type=str, default="<UNK>", help="The OOV word."
)
return parser.parse_args()
def process_line(
lexicon: Dict[str, List[str]], line: str, oov_token: str
) -> None:
"""
Args:
lexicon:
A dict containing pronunciations. Its keys are words and values
are pronunciations (i.e., tokens).
line:
A line of transcript consisting of space(s) separated words.
oov_token:
The pronunciation of the oov word if a word in `line` is not present
in the lexicon.
Returns:
Return None.
"""
s = ""
words = line.strip().split()
for i, w in enumerate(words):
tokens = lexicon.get(w, oov_token)
s += " ".join(tokens)
s += " "
print(s.strip())
def main():
args = get_args()
assert Path(args.lexicon).is_file()
assert Path(args.transcript).is_file()
assert len(args.oov) > 0
# Only the first pronunciation of a word is kept
lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
lexicon = dict(lexicon)
assert args.oov in lexicon
oov_token = lexicon[args.oov]
with open(args.transcript) as f:
for line in f:
process_line(lexicon=lexicon, line=line, oov_token=oov_token)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py

View File

@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()` in transducer_stateless/train.py
for usage.
"""
from lhotse import load_manifest
def main():
# path = "./data/fbank/cuts_train.json.gz"
# path = "./data/fbank/cuts_test.json.gz"
# path = "./data/fbank/cuts_dev.json.gz"
# path = "./data/fbank/aidatatang_200zh/cuts_train_raw.jsonl.gz"
# path = "./data/fbank/aidatatang_200zh/cuts_test_raw.jsonl.gz"
path = "./data/fbank/aidatatang_200zh/cuts_dev_raw.jsonl.gz"
cuts = load_manifest(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
## train (after speed perturb)
Cuts count: 360294
Total duration (hours): 455.6
Speech duration (hours): 455.6 (100.0%)
***
Duration statistics (seconds):
mean 4.6
std 1.4
min 1.1
0.1% 1.8
0.5% 2.2
1% 2.3
5% 2.7
10% 3.0
10% 3.0
25% 3.5
50% 4.3
75% 5.4
90% 6.5
95% 7.2
99% 8.8
99.5% 9.4
99.9% 10.9
max 16.1
## test
Cuts count: 7176
Total duration (hours): 10.0
Speech duration (hours): 10.0 (100.0%)
***
Duration statistics (seconds):
mean 5.0
std 1.6
min 1.9
0.1% 2.2
0.5% 2.4
1% 2.6
5% 3.0
10% 3.2
10% 3.2
25% 3.8
50% 4.7
75% 5.9
90% 7.3
95% 8.2
99% 9.9
99.5% 10.7
99.9% 11.9
max 14.7
## dev
Cuts count: 14326
Total duration (hours): 18.1
Speech duration (hours): 18.1 (100.0%)
***
Duration statistics (seconds):
mean 4.5
std 1.3
min 1.6
0.1% 2.1
0.5% 2.3
1% 2.4
5% 2.9
10% 3.1
10% 3.1
25% 3.5
50% 4.3
75% 5.4
90% 6.4
95% 7.0
99% 8.4
99.5% 8.9
99.9% 10.3
max 12.5
## aidatatang_200zh (train)
Cuts count: 164905
Total duration (hours): 139.9
Speech duration (hours): 139.9 (100.0%)
***
Duration statistics (seconds):
mean 3.1
std 1.1
min 1.1
0.1% 1.5
0.5% 1.7
1% 1.8
5% 2.0
10% 2.1
10% 2.1
25% 2.3
50% 2.7
75% 3.4
90% 4.6
95% 5.4
99% 7.1
99.5% 7.8
99.9% 9.1
max 16.3
## aidatatang_200zh (test)
Cuts count: 48144
Total duration (hours): 40.2
Speech duration (hours): 40.2 (100.0%)
***
Duration statistics (seconds):
mean 3.0
std 1.1
min 0.9
0.1% 1.5
0.5% 1.8
1% 1.8
5% 2.0
10% 2.1
10% 2.1
25% 2.3
50% 2.6
75% 3.4
90% 4.4
95% 5.2
99% 6.9
99.5% 7.5
99.9% 9.0
max 21.8
## aidatatang_200zh (dev)
Cuts count: 24216
Total duration (hours): 20.2
Speech duration (hours): 20.2 (100.0%)
***
Duration statistics (seconds):
mean 3.0
std 1.0
min 1.2
0.1% 1.6
0.5% 1.7
1% 1.8
5% 2.0
10% 2.1
10% 2.1
25% 2.3
50% 2.7
75% 3.4
90% 4.4
95% 5.1
99% 6.7
99.5% 7.3
99.9% 8.8
max 11.3
"""

View File

@ -1,100 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file takes as input a lexicon.txt and output a new lexicon,
in which each word has a unique pronunciation.
The way to do this is to keep only the first pronunciation of a word
in lexicon.txt.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
from icefall.lexicon import read_lexicon, write_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
This file will generate a new file uniq_lexicon.txt
in it.
""",
)
return parser.parse_args()
def filter_multiple_pronunications(
lexicon: List[Tuple[str, List[str]]]
) -> List[Tuple[str, List[str]]]:
"""Remove multiple pronunciations of words from a lexicon.
If a word has more than one pronunciation in the lexicon, only
the first one is kept, while other pronunciations are removed
from the lexicon.
Args:
lexicon:
The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
where "p1, p2, ..., pn" are the pronunciations of the "word".
Returns:
Return a new lexicon where each word has a unique pronunciation.
"""
seen = set()
ans = []
for word, tokens in lexicon:
if word in seen:
continue
seen.add(word)
ans.append((word, tokens))
return ans
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
in_lexicon = read_lexicon(lexicon_filename)
out_lexicon = filter_multiple_pronunications(in_lexicon)
write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)
logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/generate_unique_lexicon.py

View File

@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
def preprocess_aidatatang_200zh():
src_dir = Path("data/manifests/aidatatang_200zh")
output_dir = Path("data/fbank/aidatatang_200zh")
output_dir.mkdir(exist_ok=True, parents=True)
dataset_parts = (
"train",
"test",
"dev",
)
logging.info("Loading manifest")
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
)
assert len(manifests) > 0
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
for sup in m["supervisions"]:
sup.custom = {"origin": "aidatatang_200zh"}
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
def main():
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
preprocess_aidatatang_200zh()
if __name__ == "__main__":
main()

View File

@ -48,8 +48,9 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM" log "stage -1: Download LM"
# We assume that you have installed the git-lfs, if not, you could install it # We assume that you have installed the git-lfs, if not, you could install it
# using: `sudo apt-get install git-lfs && git-lfs install` # using: `sudo apt-get install git-lfs && git-lfs install`
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
fi
fi fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@ -87,28 +88,41 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare aishell manifest" log "Stage 1: Prepare aishell manifest"
# We assume that you have downloaded the aishell corpus # We assume that you have downloaded the aishell corpus
# to $dl_dir/aishell # to $dl_dir/aishell
mkdir -p data/manifests if [ ! -f data/manifests/.aishell_manifests.done ]; then
lhotse prepare aishell -j $nj $dl_dir/aishell data/manifests mkdir -p data/manifests
lhotse prepare aishell $dl_dir/aishell data/manifests
touch data/manifests/.aishell_manifests.done
fi
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest" log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus # We assume that you have downloaded the musan corpus
# to data/musan # to data/musan
mkdir -p data/manifests if [ ! -f data/manifests/.musan_manifests.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests log "It may take 6 minutes"
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_manifests.done
fi
fi fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell" log "Stage 3: Compute fbank for aishell"
mkdir -p data/fbank if [ ! -f data/fbank/.aishell.done ]; then
./local/compute_fbank_aishell.py mkdir -p data/fbank
./local/compute_fbank_aishell.py
touch data/fbank/.aishell.done
fi
fi fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan" log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank if [ ! -f data/fbank/.msuan.done ]; then
./local/compute_fbank_musan.py mkdir -p data/fbank
./local/compute_fbank_musan.py
touch data/fbank/.msuan.done
fi
fi fi
lang_phone_dir=data/lang_phone lang_phone_dir=data/lang_phone
@ -134,7 +148,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
find $dl_dir/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid find $dl_dir/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text |
cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
fi fi
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then

View File

@ -0,0 +1,59 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/aidatatang_200zh
# You can find "corpus" and "transcript" inside it.
# You can download it at
# https://openslr.org/62/
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
if [ ! -f $dl_dir/aidatatang_200zh/transcript/aidatatang_200_zh_transcript.txt ]; then
lhotse download aidatatang-200zh $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare manifest"
# We assume that you have downloaded the aidatatang_200zh corpus
# to $dl_dir/aidatatang_200zh
if [ ! -f data/manifests/aidatatang_200zh/.manifests.done ]; then
mkdir -p data/manifests/aidatatang_200zh
lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
touch data/manifests/aidatatang_200zh/.manifests.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aidatatang_200zh"
if [ ! -f data/fbank/aidatatang_200zh/.fbank.done ]; then
mkdir -p data/fbank/aidatatang_200zh
lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
touch data/fbank/aidatatang_200zh/.fbank.done
fi
fi

View File

@ -1,4 +1,5 @@
# Copyright 2021 Piotr Żelasko # Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -16,6 +17,7 @@
import argparse import argparse
import inspect
import logging import logging
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
@ -210,10 +212,20 @@ class AishellAsrDataModule:
logging.info( logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}" f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
) )
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append( input_transforms.append(
SpecAugment( SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor, time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2, num_frame_masks=num_frame_masks,
features_mask_size=27, features_mask_size=27,
num_feature_masks=2, num_feature_masks=2,
frames_mask_size=100, frames_mask_size=100,

View File

@ -92,6 +92,13 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser return parser
@ -507,7 +514,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
fix_random_seed(42) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -557,6 +564,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch: if epoch > params.start_epoch:

View File

@ -31,7 +31,6 @@ from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from model import Transducer from model import Transducer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
@ -403,12 +402,9 @@ def main():
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0] # params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
logging.info(params) logging.info(params)

View File

@ -14,15 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_sos
@ -108,18 +102,13 @@ class Transducer(nn.Module):
# Note: y does not start with SOS # Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
assert hasattr(torchaudio.functional, "rnnt_loss"), ( y_padded = y_padded.to(torch.int64)
f"Current torchaudio version: {torchaudio.__version__}\n" boundary = torch.zeros(
"Please install a version >= 0.10.0" (x.size(0), 4), dtype=torch.int64, device=x.device
) )
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
loss = torchaudio.functional.rnnt_loss( loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary)
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
return loss return loss

View File

@ -129,6 +129,13 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser return parser
@ -534,7 +541,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
fix_random_seed(42) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -558,7 +565,7 @@ def run(rank, world_size, args):
oov="<unk>", oov="<unk>",
) )
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0] params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
logging.info(params) logging.info(params)
@ -611,6 +618,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate cur_lr = optimizer._rate

View File

@ -0,0 +1,59 @@
## Introduction
The decoder, i.e., the prediction network, is from
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
(Rnn-Transducer with Stateless Prediction Network)
Different from `../transducer_stateless_modified`, this folder
uses extra data, i.e., http://www.openslr.org/62/, during training.
You can use the following command to start the training:
```bash
cd egs/aishell/ASR
./prepare.sh --stop-stage 6
./prepare_aidatatang_200zh.sh
export CUDA_VISIBLE_DEVICES="0,1,2"
./transducer_stateless_modified-2/train.py \
--world-size 3 \
--num-epochs 90 \
--start-epoch 0 \
--exp-dir transducer_stateless_modified-2/exp-2 \
--max-duration 250 \
--lr-factor 2.0 \
--context-size 2 \
--modified-transducer-prob 0.25 \
--datatang-prob 0.2
```
To decode, you can use
```bash
for epoch in 89; do
for avg in 30 38; do
./transducer_stateless_modified-2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_modified-2/exp-2 \
--max-duration 100 \
--context-size 2 \
--decoding-method greedy_search \
--max-sym-per-frame 1
done
done
for epoch in 89; do
for avg in 38; do
./transducer_stateless_modified-2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_modified-2/exp-2 \
--max-duration 100 \
--context-size 2 \
--decoding-method modified_beam_search \
--beam-size 4
done
done
```

View File

@ -0,0 +1,53 @@
# Copyright 2021 Piotr Żelasko
# 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest
class AIDatatang200zh:
def __init__(self, manifest_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files::
- cuts_dev_raw.jsonl.gz
- cuts_train_raw.jsonl.gz
- cuts_test_raw.jsonl.gz
"""
self.manifest_dir = Path(manifest_dir)
def train_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_train_raw.jsonl.gz"
logging.info(f"About to get train cuts from {f}")
cuts_train = load_manifest(f)
return cuts_train
def valid_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_valid_raw.jsonl.gz"
logging.info(f"About to get valid cuts from {f}")
cuts_valid = load_manifest(f)
return cuts_valid
def test_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_test_raw.jsonl.gz"
logging.info(f"About to get test cuts from {f}")
cuts_test = load_manifest(f)
return cuts_test

View File

@ -0,0 +1,53 @@
# Copyright 2021 Piotr Żelasko
# 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest
class AIShell:
def __init__(self, manifest_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files::
- cuts_dev.json.gz
- cuts_train.json.gz
- cuts_test.json.gz
"""
self.manifest_dir = Path(manifest_dir)
def train_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_train.json.gz"
logging.info(f"About to get train cuts from {f}")
cuts_train = load_manifest(f)
return cuts_train
def valid_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_dev.json.gz"
logging.info(f"About to get valid cuts from {f}")
cuts_valid = load_manifest(f)
return cuts_valid
def test_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_test.json.gz"
logging.info(f"About to get test cuts from {f}")
cuts_test = load_manifest(f)
return cuts_test

View File

@ -0,0 +1,316 @@
# Copyright 2021 Piotr Żelasko
# 2022 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from pathlib import Path
from typing import Optional
from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import (
BucketingSampler,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
SpecAugment,
)
from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
PrecomputedFeatures,
)
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class AsrDataModule:
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler "
"and DynamicBucketingSampler."
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available. Used only in dev/test CutSet",
)
def train_dataloaders(
self,
cuts_train: CutSet,
dynamic_bucketing: bool,
on_the_fly_feats: bool,
cuts_musan: Optional[CutSet] = None,
) -> DataLoader:
"""
Args:
cuts_train:
Cuts for training.
cuts_musan:
If not None, it is the cuts for mixing.
dynamic_bucketing:
True to use DynamicBucketingSampler;
False to use BucketingSampler.
on_the_fly_feats:
True to use OnTheFlyFeatures;
False to use PrecomputedFeatures.
"""
transforms = []
if cuts_musan is not None:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if on_the_fly_feats
else PrecomputedFeatures()
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if dynamic_bucketing:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=True,
)
else:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = BucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/beam_search.py

View File

@ -0,0 +1 @@
../transducer_stateless_modified/conformer.py

View File

@ -0,0 +1,491 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless_modified-2/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_stateless_modified/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_stateless_modified-2/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from aishell import AIShell
from asr_datamodule import AsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless_modified-2/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search "
"and modified_beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the token symbol table and the word symbol table.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
else:
return {f"beam_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device)
model.eval()
model.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
asr_datamodule = AsrDataModule(args)
aishell = AIShell(manifest_dir=args.manifest_dir)
test_cuts = aishell.test_cuts()
test_dl = asr_datamodule.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../transducer_stateless_modified/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless_modified/encoder_interface.py

View File

@ -0,0 +1,246 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_stateless_modified-2/export.py \
--exp-dir ./transducer_stateless_modified-2/exp \
--epoch 89 \
--avg 38
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_stateless_modified-2/decode.py`,
you can do::
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/aishell/ASR
./transducer_stateless_modified-2/decode.py \
--exp-dir ./transducer_stateless_modified-2/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=20,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=Path,
default=Path("transducer_stateless_modified-2/exp"),
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def main():
args = get_parser().parse_args()
assert args.jit is False, "torchscript support will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../transducer_stateless_modified/joiner.py

View File

@ -0,0 +1,163 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Optional
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
decoder_datatang: Optional[nn.Module] = None,
joiner_datatang: Optional[nn.Module] = None,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
decoder_datatang:
The decoder for the aidatatang_200zh dataset.
joiner_datatang:
The joiner for the aidatatang_200zh dataset.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
if decoder_datatang is not None:
assert hasattr(decoder_datatang, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
aishell: bool = True,
modified_transducer_prob: float = 0.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
modified_transducer_prob:
The probability to use modified transducer loss.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
if aishell:
decoder = self.decoder
joiner = self.joiner
else:
decoder = self.decoder_datatang
joiner = self.joiner_datatang
decoder_out = decoder(sos_y_padded)
# +1 here since a blank is prepended to each utterance.
logits = joiner(
encoder_out=encoder_out,
decoder_out=decoder_out,
encoder_out_len=x_lens,
decoder_out_len=y_lens + 1,
)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
# We don't put this `import` at the beginning of the file
# as it is required only in the training, not during the
# reference stage
import optimized_transducer
assert 0 <= modified_transducer_prob <= 1
if modified_transducer_prob == 0:
one_sym_per_frame = False
elif random.random() < modified_transducer_prob:
# random.random() returns a float in the range [0, 1)
one_sym_per_frame = True
else:
one_sym_per_frame = False
loss = optimized_transducer.transducer_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
one_sym_per_frame=one_sym_per_frame,
from_log_softmax=False,
)
return loss

View File

@ -0,0 +1,331 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# greedy search
./transducer_stateless_modified-2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
# beam search
./transducer_stateless_modified-2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
# modified beam search
./transducer_stateless_modified-2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import kaldifeat
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
return parser
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
"sample_rate": 16000,
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
hyps = []
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
for i in range(encoder_out.size(0)):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
elif params.method == "modified_beam_search":
hyp = modified_beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../conformer_ctc/subsampling.py

View File

@ -0,0 +1 @@
../transducer_stateless_modified/test_decoder.py

View File

@ -0,0 +1,875 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
# Copyright 2021 (Pingfeng Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./prepare.sh
./prepare_aidatatang_200zh.sh
export CUDA_VISIBLE_DEVICES="0,1,2"
./transducer_stateless_modified-2/train.py \
--world-size 3 \
--num-epochs 90 \
--start-epoch 0 \
--exp-dir transducer_stateless_modified-2/exp-2 \
--max-duration 250 \
--lr-factor 2.0 \
--context-size 2 \
--modified-transducer-prob 0.25 \
--datatang-prob 0.2
"""
import argparse
import logging
import random
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from aidatatang_200zh import AIDatatang200zh
from aishell import AIShell
from asr_datamodule import AsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless_modified-2/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--lr-factor",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
In modified transduer, it limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
transducer_stateless/decode.py
""",
)
parser.add_argument(
"--datatang-prob",
type=float,
default=0.2,
help="The probability to select a batch from the "
"aidatatang_200zh dataset",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 800, # For the 100h subset, use 800
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for Noam
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
decoder_datatang = get_decoder_model(params)
joiner_datatang = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
decoder_datatang=decoder_datatang,
joiner_datatang=joiner_datatang,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def is_aishell(c: Cut) -> bool:
"""Return True if this cut is from the AIShell dataset.
Note:
During data preparation, we set the custom field in
the supervision segment of aidatatang_200zh to
dict(origin='aidatatang_200zh')
See ../local/process_aidatatang_200zh.py.
"""
return c.supervisions[0].custom is None
def compute_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
aishell = is_aishell(supervisions["cut"][0])
texts = batch["supervisions"]["text"]
y = graph_compiler.texts_to_ids(texts)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(
x=feature,
x_lens=feature_lens,
y=y,
aishell=aishell,
modified_transducer_prob=params.modified_transducer_prob,
)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
datatang_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
datatang_train_dl:
Dataloader for the aidatatang_200zh training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
aishell_tot_loss = MetricsTracker()
datatang_tot_loss = MetricsTracker()
tot_loss = MetricsTracker()
# index 0: for LibriSpeech
# index 1: for GigaSpeech
# This sets the probabilities for choosing which datasets
dl_weights = [1 - params.datatang_prob, params.datatang_prob]
iter_aishell = iter(train_dl)
iter_datatang = iter(datatang_train_dl)
batch_idx = 0
while True:
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
dl = iter_aishell if idx == 0 else iter_datatang
try:
batch = next(dl)
except StopIteration:
break
batch_idx += 1
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
aishell = is_aishell(batch["supervisions"]["cut"][0])
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
if aishell:
aishell_tot_loss = (
aishell_tot_loss * (1 - 1 / params.reset_interval)
) + loss_info
prefix = "aishell" # for logging only
else:
datatang_tot_loss = (
datatang_tot_loss * (1 - 1 / params.reset_interval)
) + loss_info
prefix = "datatang"
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"aishell_tot_loss[{aishell_tot_loss}], "
f"datatang_tot_loss[{datatang_tot_loss}], "
f"batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer,
f"train/current_{prefix}_",
params.batch_idx_train,
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
aishell_tot_loss.write_summary(
tb_writer, "train/aishell_tot_", params.batch_idx_train
)
datatang_tot_loss.write_summary(
tb_writer, "train/datatang_tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 12 seconds
return 1.0 <= c.duration <= 12.0
num_in_total = len(cuts)
cuts = cuts.filter(remove_short_and_long_utt)
num_left = len(cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
return cuts
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
seed = 42
fix_random_seed(seed)
rng = random.Random(seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
oov="<unk>",
)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
aishell = AIShell(manifest_dir=args.manifest_dir)
train_cuts = aishell.train_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts)
datatang = AIDatatang200zh(
manifest_dir=f"{args.manifest_dir}/aidatatang_200zh"
)
train_datatang_cuts = datatang.train_cuts()
train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
if args.enable_musan:
cuts_musan = load_manifest(
Path(args.manifest_dir) / "cuts_musan.json.gz"
)
else:
cuts_musan = None
asr_datamodule = AsrDataModule(args)
train_dl = asr_datamodule.train_dataloaders(
train_cuts,
dynamic_bucketing=False,
on_the_fly_feats=False,
cuts_musan=cuts_musan,
)
datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts,
dynamic_bucketing=True,
on_the_fly_feats=True,
cuts_musan=cuts_musan,
)
valid_cuts = aishell.valid_cuts()
valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
for dl in [train_dl, datatang_train_dl]:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
datatang_train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
datatang_train_dl=datatang_train_dl,
valid_dl=valid_dl,
rng=rng,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
assert 0 <= args.datatang_prob < 1, args.datatang_prob
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../transducer_stateless_modified/transformer.py

View File

@ -0,0 +1,21 @@
## Introduction
The decoder, i.e., the prediction network, is from
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
(Rnn-Transducer with Stateless Prediction Network)
You can use the following command to start the training:
```bash
cd egs/aishell/ASR
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./transducer_stateless_modified/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless_modified/exp \
--max-duration 250 \
--lr-factor 2.5
```

View File

@ -0,0 +1 @@
../conformer_ctc/asr_datamodule.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/conformer.py

View File

@ -0,0 +1,486 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless_modified/decode.py \
--epoch 64 \
--avg 33 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_stateless_modified/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_stateless_modified/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless_modified/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the token symbol table and the word symbol table.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
else:
return {f"beam_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,246 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_stateless_modified/export.py \
--exp-dir ./transducer_stateless_modified/exp \
--epoch 64 \
--avg 33
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_stateless_modified/decode.py`,
you can do::
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/aishell/ASR
./transducer_stateless_modified/decode.py \
--exp-dir ./transducer_stateless_modified/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=20,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=Path,
default=Path("transducer_stateless_modified/exp"),
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def main():
args = get_parser().parse_args()
assert args.jit is False, "torchscript support will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/model.py

View File

@ -0,0 +1,331 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# greedy search
./transducer_stateless_modified/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
# beam search
./transducer_stateless_modified/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
# modified beam search
./transducer_stateless_modified/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import kaldifeat
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
return parser
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
"sample_rate": 16000,
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
hyps = []
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
for i in range(encoder_out.size(0)):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
elif params.method == "modified_beam_search":
hyp = modified_beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../conformer_ctc/subsampling.py

View File

@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/aishell/ASR
python ./transducer_stateless/test_decoder.py
"""
import torch
from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
embedding_dim = 128
context_size = 4
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
context_size=context_size,
)
N = 100
U = 20
x = torch.randint(low=0, high=vocab_size, size=(N, U))
y = decoder(x)
assert y.shape == (N, U, embedding_dim)
# for inference
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
y = decoder(x, need_pad=False)
assert y.shape == (N, 1, embedding_dim)
def main():
test_decoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,751 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
# Copyright 2021 (Pingfeng Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2"
./transducer_stateless_modified/train.py \
--world-size 3 \
--num-epochs 65 \
--start-epoch 0 \
--exp-dir transducer_stateless_modified/exp \
--max-duration 250 \
--lr-factor 2.0 \
--context-size 2 \
--modified-transducer-prob 0.25
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless_modified/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--lr-factor",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
In modified transduer, it limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
transducer_stateless/decode.py
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 800,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for Noam
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = graph_compiler.texts_to_ids(texts)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(
x=feature,
x_lens=feature_lens,
y=y,
modified_transducer_prob=params.modified_transducer_prob,
)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
oov="<unk>",
)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
aishell = AishellAsrDataModule(args)
train_cuts = aishell.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 12 seconds
return 1.0 <= c.duration <= 12.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = aishell.train_dataloaders(train_cuts)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/transformer.py

View File

@ -1,7 +1,7 @@
# Introduction # Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech.html> Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech/index.html>
for how to run models in this recipe. for how to run models in this recipe.
# Transducers # Transducers
@ -9,11 +9,13 @@ for how to run models in this recipe.
There are various folders containing the name `transducer` in this folder. There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them. The following table lists the differences among them.
| | Encoder | Decoder | | | Encoder | Decoder | Comment |
|------------------------|-----------|--------------------| |---------------------------------------|-----------|--------------------|---------------------------------------------------|
| `transducer` | Conformer | LSTM | | `transducer` | Conformer | LSTM | |
| `transducer_stateless` | Conformer | Embedding + Conv1d | | `transducer_stateless` | Conformer | Embedding + Conv1d | |
| `transducer_lstm ` | LSTM | LSTM | | `transducer_lstm` | LSTM | LSTM | |
| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
The decoder in `transducer_stateless` is modified from the paper The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -0,0 +1,77 @@
# Results for train-clean-100
This page shows the WERs for test-clean/test-other using only
train-clean-100 subset as training data.
## Conformer encoder + embedding decoder
### 2022-02-21
Using commit `2332ba312d7ce72f08c7bac1e3312f7e3dd722dc`.
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|------------------------------------------|
| greedy search (max sym per frame 1) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
| greedy search (max sym per frame 2) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
| greedy search (max sym per frame 3) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
| modified beam search (beam size 4) | 6.31 | 16.3 | --epoch 57, --avg 17, --max-duration 100 |
The training command for reproducing is given below:
```bash
cd egs/librispeech/ASR/
./prepare.sh
./prepare_giga_speech.sh
export CUDA_VISIBLE_DEVICES="0,1"
./transducer_stateless_multi_datasets/train.py \
--world-size 2 \
--num-epochs 60 \
--start-epoch 0 \
--exp-dir transducer_stateless_multi_datasets/exp-100-2 \
--full-libri 0 \
--max-duration 300 \
--lr-factor 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--modified-transducer-prob 0.25
--giga-prob 0.2
```
The decoding command is given below:
```bash
for epoch in 57; do
for avg in 17; do
for sym in 1 2 3; do
./transducer_stateless_multi_datasets/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_multi_datasets/exp-100-2 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--context-size 2 \
--max-sym-per-frame $sym
done
done
done
epoch=57
avg=17
./transducer_stateless_multi_datasets/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_multi_datasets/exp-100-2 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--context-size 2 \
--decoding-method modified_beam_search \
--beam-size 4
```
The tensorboard log is available at
<https://tensorboard.dev/experiment/qUEKzMnrTZmOz1EXPda9RA/>
A pre-trained model and decoding logs can be found at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21>

View File

@ -1,65 +1,304 @@
## Results ## Results
### LibriSpeech BPE training results (Transducer) ### LibriSpeech BPE training results (Pruned Transducer)
#### Conformer encoder + embedding decoder
Using commit `4c1b3665ee6efb935f4dd93a80ff0e154b13efb6`.
Conformer encoder + non-current decoder. The decoder Conformer encoder + non-current decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2). contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform tensor dim).
#### 2022-03-12
[pruned_transducer_stateless](./pruned_transducer_stateless)
Using commit `1603744469d167d848e074f2ea98c587153205fa`.
See <https://github.com/k2-fsa/icefall/pull/248>
The WERs are:
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|------------------------------------------|
| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
| greedy search (max sym per frame 2) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
| greedy search (max sym per frame 3) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
| modified beam search (beam size 4) | 2.56 | 6.27 | --epoch 42, --avg 11, --max-duration 100 |
| beam search (beam size 4) | 2.57 | 6.27 | --epoch 42, --avg 11, --max-duration 100 |
The decoding time for `test-clean` and `test-other` is given below:
(A V100 GPU with 32 GB RAM is used for decoding. Note: Not all GPU RAM is used during decoding.)
| decoding method | test-clean (seconds) | test-other (seconds)|
|---|---:|---:|
| greedy search (--max-sym-per-frame=1) | 160 | 159 |
| greedy search (--max-sym-per-frame=2) | 184 | 177 |
| greedy search (--max-sym-per-frame=3) | 210 | 213 |
| modified beam search (--beam-size 4)| 273 | 269 |
|beam search (--beam-size 4) | 2741 | 2221 |
We recommend you to use `modified_beam_search`.
Training command:
```bash
cd egs/librispeech/ASR/
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
. path.sh
./pruned_transducer_stateless/train.py \
--world-size 8 \
--num-epochs 60 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25
```
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/WKRFY5fYSzaVBHahenpNlA/>
The command for decoding is:
```bash
epoch=42
avg=11
sym=1
# greedy search
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search \
--beam-size 4 \
--max-sym-per-frame $sym
# modified beam search
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
# beam search
# (not recommended)
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
```
You can find a pre-trained model, decoding logs, and decoding results at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>
#### 2022-02-18
[pruned_transducer_stateless](./pruned_transducer_stateless)
The WERs are The WERs are
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------| |---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.69 | 6.81 | --epoch 71, --avg 15, --max-duration 100 | | greedy search | 2.85 | 6.98 | --epoch 28, --avg 15, --max-duration 100 |
| beam search (beam size 4) | 2.68 | 6.72 | --epoch 71, --avg 15, --max-duration 100 |
The training command for reproducing is given below: The training command for reproducing is given below:
``` ```
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25 \
```
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/ejG7VpakRYePNNj6AbDEUw/#scalars>
The decoding command is:
```
epoch=28
avg=15
## greedy search
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless/exp \
--max-duration 100
```
### LibriSpeech BPE training results (Transducer)
#### Conformer encoder + embedding decoder
Conformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2).
See
- [./transducer_stateless](./transducer_stateless)
- [./transducer_stateless_multi_datasets](./transducer_stateless_multi_datasets)
##### 2022-03-01
Using commit `2332ba312d7ce72f08c7bac1e3312f7e3dd722dc`.
It uses [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
as extra training data. 20% of the time it selects a batch from L subset of
GigaSpeech and 80% of the time it selects a batch from LibriSpeech.
The WERs are
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|------------------------------------------|
| greedy search (max sym per frame 1) | 2.64 | 6.55 | --epoch 39, --avg 15, --max-duration 100 |
| modified beam search (beam size 4) | 2.61 | 6.46 | --epoch 39, --avg 15, --max-duration 100 |
The training command for reproducing is given below:
```bash
cd egs/librispeech/ASR/
./prepare.sh
./prepare_giga_speech.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless_multi_datasets/train.py \
--world-size 4 \
--num-epochs 40 \
--start-epoch 0 \
--exp-dir transducer_stateless_multi_datasets/exp-full-2 \
--full-libri 1 \
--max-duration 300 \
--lr-factor 5 \
--bpe-model data/lang_bpe_500/bpe.model \
--modified-transducer-prob 0.25 \
--giga-prob 0.2
```
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/xmo5oCgrRVelH9dCeOkYBg/>
The decoding command is:
```bash
epoch=39
avg=15
sym=1
# greedy search
./transducer_stateless_multi_datasets/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_multi_datasets/exp-full-2 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--context-size 2 \
--max-sym-per-frame $sym
# modified beam search
./transducer_stateless_multi_datasets/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer_stateless_multi_datasets/exp-full-2 \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--context-size 2 \
--decoding-method modified_beam_search \
--beam-size 4
```
You can find a pretrained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01>
##### 2022-02-07
Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`.
The WERs are
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|------------------------------------------|
| greedy search (max sym per frame 1) | 2.67 | 6.67 | --epoch 63, --avg 19, --max-duration 100 |
| greedy search (max sym per frame 2) | 2.67 | 6.67 | --epoch 63, --avg 19, --max-duration 100 |
| greedy search (max sym per frame 3) | 2.67 | 6.67 | --epoch 63, --avg 19, --max-duration 100 |
| modified beam search (beam size 4) | 2.67 | 6.57 | --epoch 63, --avg 19, --max-duration 100 |
The training command for reproducing is given below:
```
cd egs/librispeech/ASR/
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \ ./transducer_stateless/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 76 \ --num-epochs 76 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir transducer_stateless/exp-full \ --exp-dir transducer_stateless/exp-full \
--full-libri 1 \ --full-libri 1 \
--max-duration 250 \ --max-duration 300 \
--lr-factor 3 --lr-factor 5 \
--bpe-model data/lang_bpe_500/bpe.model \
--modified-transducer-prob 0.25
``` ```
The tensorboard training log can be found at The tensorboard training log can be found at
<https://tensorboard.dev/experiment/qGdqzHnxS0WJ695OXfZDzA/#scalars&_smoothingWeight=0> <https://tensorboard.dev/experiment/qgvWkbF2R46FYA6ZMNmOjA/#scalars>
The decoding command is: The decoding command is:
``` ```
epoch=71 epoch=63
avg=15 avg=19
## greedy search ## greedy search
./transducer_stateless/decode.py \ for sym in 1 2 3; do
--epoch $epoch \ ./transducer_stateless/decode.py \
--avg $avg \ --epoch $epoch \
--exp-dir transducer_stateless/exp-full \ --avg $avg \
--bpe-model ./data/lang_bpe_500/bpe.model \ --exp-dir transducer_stateless/exp-full \
--max-duration 100 --bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--max-sym-per-frame $sym
done
## modified beam search
## beam search
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--epoch $epoch \ --epoch $epoch \
--avg $avg \ --avg $avg \
--exp-dir transducer_stateless/exp-full \ --exp-dir transducer_stateless/exp-full \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --context-size 2 \
--decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
``` ```
You can find a pretrained model by visiting You can find a pretrained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10> <https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07>
#### Conformer encoder + LSTM decoder #### Conformer encoder + LSTM decoder

View File

@ -140,6 +140,13 @@ def get_parser():
help="The lr_factor for Noam optimizer", help="The lr_factor for Noam optimizer",
) )
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser return parser
@ -580,7 +587,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
fix_random_seed(42) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -601,14 +608,14 @@ def run(rank, world_size, args):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", rank) device = torch.device("cuda", rank)
if "lang_bpe" in params.lang_dir: if "lang_bpe" in str(params.lang_dir):
graph_compiler = BpeCtcTrainingGraphCompiler( graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir, params.lang_dir,
device=device, device=device,
sos_token="<sos/eos>", sos_token="<sos/eos>",
eos_token="<sos/eos>", eos_token="<sos/eos>",
) )
elif "lang_phone" in params.lang_dir: elif "lang_phone" in str(params.lang_dir):
assert params.att_rate == 0, ( assert params.att_rate == 0, (
"Attention decoder training does not support phone lang dirs " "Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 " "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
@ -650,9 +657,7 @@ def run(rank, world_size, args):
model.to(device) model.to(device)
if world_size > 1: if world_size > 1:
# Note: find_unused_parameters=True is needed in case we model = DDP(model, device_ids=[rank])
# want to set params.att_rate = 0 (i.e. att decoder is not trained)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = Noam( optimizer = Noam(
model.parameters(), model.parameters(),
@ -686,6 +691,7 @@ def run(rank, world_size, args):
) )
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate cur_lr = optimizer._rate

View File

@ -1,356 +0,0 @@
# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule):
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. "
"Otherwise, use 100h subset.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
logging.info("About to create train dataset")
transforms = [
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = [
SpecAugment(
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
if not is_list:
cuts = [cuts]
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
)
test_loaders.append(test_dl)
if is_list:
return test_loaders
else:
return test_loaders[0]
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts

View File

@ -0,0 +1 @@
../conformer_ctc/asr_datamodule.py

View File

@ -109,6 +109,13 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser return parser
@ -673,7 +680,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
fix_random_seed(42) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -761,6 +768,7 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders() valid_dl = librispeech.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
if ( if (
params.batch_idx_train >= params.use_ali_until params.batch_idx_train >= params.use_ali_until

View File

@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from pathlib import Path
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
# Similar text filtering and normalization procedure as in:
# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
def normalize_text(
utt: str,
punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
whitespace_pattern=re.compile(r"\s\s+"),
) -> str:
return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))
def has_no_oov(
sup: SupervisionSegment,
oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"),
) -> bool:
return oov_pattern.search(sup.text) is None
def preprocess_giga_speech():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
dataset_parts = (
"DEV",
"TEST",
"XS",
"S",
"M",
"L",
"XL",
)
logging.info("Loading manifest (may take 4 minutes)")
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix="gigaspeech",
suffix="jsonl.gz",
)
assert manifests is not None
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
# Note this step makes the recipe different than LibriSpeech:
# We must filter out some utterances and remove punctuation
# to be consistent with Kaldi.
logging.info("Filtering OOV utterances from supervisions")
m["supervisions"] = m["supervisions"].filter(has_no_oov)
logging.info(f"Normalizing text in {partition}")
for sup in m["supervisions"]:
sup.text = normalize_text(sup.text)
sup.custom = {"origin": "giga"}
# Create long-recording cut manifests.
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST"]:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
def main():
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
preprocess_giga_speech()
if __name__ == "__main__":
main()

View File

@ -60,8 +60,11 @@ log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download LM" log "Stage -1: Download LM"
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm mkdir -p $dl_dir/lm
./local/download_lm.py --out-dir=$dl_dir/lm if [ ! -e $dl_dir/lm/.done ]; then
./local/download_lm.py --out-dir=$dl_dir/lm
touch $dl_dir/lm/.done
fi
fi fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@ -91,7 +94,10 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We assume that you have downloaded the LibriSpeech corpus # We assume that you have downloaded the LibriSpeech corpus
# to $dl_dir/LibriSpeech # to $dl_dir/LibriSpeech
mkdir -p data/manifests mkdir -p data/manifests
lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests if [ ! -e data/manifests/.librispeech.done ]; then
lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
touch data/manifests/.librispeech.done
fi
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@ -99,19 +105,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# We assume that you have downloaded the musan corpus # We assume that you have downloaded the musan corpus
# to data/musan # to data/musan
mkdir -p data/manifests mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for librispeech" log "Stage 3: Compute fbank for librispeech"
mkdir -p data/fbank mkdir -p data/fbank
./local/compute_fbank_librispeech.py if [ ! -e data/fbank/.librispeech.done ]; then
./local/compute_fbank_librispeech.py
touch data/fbank/.librispeech.done
fi
fi fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan" log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank mkdir -p data/fbank
./local/compute_fbank_musan.py if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then

View File

@ -0,0 +1,109 @@
#!/usr/bin/env bash
set -eou pipefail
nj=15
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/GigaSpeech
# You can find audio, dict, GigaSpeech.json inside it.
# You can apply for the download credentials by following
# https://github.com/SpeechColab/GigaSpeech#download
# Number of hours for GigaSpeech subsets
# XL 10k hours
# L 2.5k hours
# M 1k hours
# S 250 hours
# XS 10 hours
# DEV 12 hours
# Test 40 hours
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
[ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech
# If you have pre-downloaded it to /path/to/GigaSpeech,
# you can create a symlink
#
# ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech
#
if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then
# Check credentials.
if [ ! -f $dl_dir/password ]; then
echo -n "$0: Please apply for the download credentials by following"
echo -n "https://github.com/SpeechColab/GigaSpeech#dataset-download"
echo " and save it to $dl_dir/password."
exit 1;
fi
PASSWORD=`cat $dl_dir/password 2>/dev/null`
if [ -z "$PASSWORD" ]; then
echo "$0: Error, $dl_dir/password is empty."
exit 1;
fi
PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1`
if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then
echo "$0: Error, invalid $dl_dir/password."
exit 1;
fi
# Download XL, DEV and TEST sets by default.
lhotse download gigaspeech \
--subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
--host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare GigaSpeech manifest (may take 30 minutes)"
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
mkdir -p data/manifests
lhotse prepare gigaspeech \
--subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
-j $nj \
$dl_dir/GigaSpeech data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Preprocess GigaSpeech manifest"
if [ ! -f data/fbank/.preprocess_complete ]; then
log "It may take 2 hours for this stage"
python3 ./local/preprocess_gigaspeech.py
touch data/fbank/.preprocess_complete
fi
fi

Some files were not shown because too many files have changed in this diff Show More