Merge branch 'dev_swbd' of https://github.com/JinZr/icefall into dev_swbd

This commit is contained in:
JinZr 2023-08-12 14:30:52 +08:00
commit 58d9088010
171 changed files with 16949 additions and 579 deletions

View File

@ -21,9 +21,9 @@ tree $repo/
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/Linv.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/cpu_jit.pt"

View File

@ -0,0 +1,45 @@
# see also
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
name: Build docker image
on:
workflow_dispatch:
concurrency:
group: build_docker-${{ github.ref }}
cancel-in-progress: true
jobs:
build-docker-image:
name: ${{ matrix.image }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Rename
shell: bash
run: |
image=${{ matrix.image }}
mv -v ./docker/$image.dockerfile ./Dockerfile
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Build and push
uses: docker/build-push-action@v4
with:
context: .
file: ./Dockerfile
push: true
tags: k2fsa/icefall:${{ matrix.image }}

View File

@ -44,7 +44,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -119,5 +119,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-06-20
path: egs/aishell/ASR/pruned_transducer_stateless3/exp/

92
.github/workflows/run-docker-image.yml vendored Normal file
View File

@ -0,0 +1,92 @@
name: Run docker image
on:
workflow_dispatch:
concurrency:
group: run_docker_image-${{ github.ref }}
cancel-in-progress: true
jobs:
run-docker-image:
name: ${{ matrix.image }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Run the build process with Docker
uses: addnab/docker-run-action@v3
with:
image: k2fsa/icefall:${{ matrix.image }}
shell: bash
run: |
uname -a
cat /etc/*release
nvcc --version
# For torch1.9.0-cuda10.2
export LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat:$LD_LIBRARY_PATH
# For torch1.12.1-cuda11.3
export LD_LIBRARY_PATH=/usr/local/cuda-11.3/compat:$LD_LIBRARY_PATH
# For torch2.0.0-cuda11.7
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/compat:$LD_LIBRARY_PATH
which nvcc
cuda_dir=$(dirname $(which nvcc))
echo "cuda_dir: $cuda_dir"
find $cuda_dir -name libcuda.so*
echo "--------------------"
find / -name libcuda.so* 2>/dev/null
# for torch1.13.0-cuda11.6
if [ -e /opt/conda/lib/stubs/libcuda.so ]; then
cd /opt/conda/lib/stubs && ln -s libcuda.so libcuda.so.1 && cd -
export LD_LIBRARY_PATH=/opt/conda/lib/stubs:$LD_LIBRARY_PATH
fi
find / -name libcuda.so* 2>/dev/null
echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
python3 --version
which python3
python3 -m pip list
echo "----------torch----------"
python3 -m torch.utils.collect_env
echo "----------k2----------"
python3 -c "import k2; print(k2.__file__)"
python3 -c "import k2; print(k2.__dev_version__)"
python3 -m k2.version
echo "----------lhotse----------"
python3 -c "import lhotse; print(lhotse.__file__)"
python3 -c "import lhotse; print(lhotse.__version__)"
echo "----------kaldifeat----------"
python3 -c "import kaldifeat; print(kaldifeat.__file__)"
python3 -c "import kaldifeat; print(kaldifeat.__version__)"
echo "Test yesno recipe"
cd egs/yesno/ASR
./prepare.sh
./tdnn/train.py
./tdnn/decode.py

View File

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -122,5 +122,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/

View File

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless-2022-03-12
path: egs/librispeech/ASR/pruned_transducer_stateless/exp/

View File

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -174,12 +174,12 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
- name: Upload decoding results for pruned_transducer_stateless3
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/

View File

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless5-2022-05-13
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless5-2022-05-13
path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/

View File

@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-2022-11-11
path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/

View File

@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14
path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/

View File

@ -159,5 +159,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-2022-12-01
path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/

View File

@ -163,5 +163,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer_mmi-2022-12-08
path: egs/librispeech/ASR/zipformer_mmi/exp/

View File

@ -168,5 +168,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-streaming-2022-12-29
path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-2022-12-15-stateless7-ctc-bs
name: run-librispeech-2023-01-29-stateless7-ctc-bs
# zipformer
on:
@ -34,7 +34,7 @@ on:
- cron: "50 15 * * *"
jobs:
run_librispeech_2022_12_15_zipformer_ctc_bs:
run_librispeech_2023_01_29_zipformer_ctc_bs:
if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
@ -124,7 +124,7 @@ jobs:
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh
- name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
@ -159,5 +159,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29
path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/

View File

@ -151,5 +151,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-conformer_ctc3-2022-11-28
path: egs/librispeech/ASR/conformer_ctc3/exp/

View File

@ -26,7 +26,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false
@ -159,5 +159,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/

View File

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -153,5 +153,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/

View File

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-06-26
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-06-26
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/

View File

@ -170,5 +170,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/

View File

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless2-2022-04-19
path: egs/librispeech/ASR/transducer_stateless2/exp/

View File

@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/

View File

@ -151,5 +151,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/

View File

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false

View File

@ -42,7 +42,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -154,5 +154,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/

View File

@ -42,7 +42,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -154,5 +154,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/

View File

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false

View File

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false

View File

@ -42,7 +42,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
@ -154,5 +154,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless-2022-02-07
path: egs/librispeech/ASR/transducer_stateless/exp/

View File

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
fail-fast: false

View File

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false

View File

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# os: [ubuntu-18.04, macos-10.15]
# os: [ubuntu-latest, macos-10.15]
# TODO: enable macOS for CPU testing
os: [ubuntu-latest]
python-version: [3.8]

View File

@ -35,9 +35,9 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ["3.8"]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.23.2.dev20221201"]
torch: ["1.13.0"]
torchaudio: ["0.13.0"]
k2-version: ["1.24.3.dev20230719"]
fail-fast: false
@ -66,14 +66,14 @@ jobs:
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html
pip install git+https://github.com/lhotse-speech/lhotse
# icefall requirements
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf==3.20.*
pip install kaldifst
pip install onnxruntime
pip install onnxruntime matplotlib
pip install -r requirements.txt
- name: Install graphviz
@ -83,13 +83,6 @@ jobs:
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz
- name: Install graphviz
if: startsWith(matrix.os, 'macos')
shell: bash
run: |
python3 -m pip install -qq graphviz
brew install -q graphviz
- name: Run tests
if: startsWith(matrix.os, 'ubuntu')
run: |
@ -129,40 +122,10 @@ jobs:
cd ../transducer_lstm
pytest -v -s
- name: Run tests
if: startsWith(matrix.os, 'macos')
run: |
ls -lh
export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
echo "lib_path: $lib_path"
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
pytest -v -s ./test
# run tests for conformer ctc
cd egs/librispeech/ASR/conformer_ctc
cd ../zipformer
pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
# cd ../transducer
# pytest -v -s
cd ../transducer_stateless2
pytest -v -s
cd ../transducer_lstm
pytest -v -s
- uses: actions/upload-artifact@v2
with:
path: egs/librispeech/ASR/zipformer/swoosh.pdf
name: swoosh.pdf

View File

@ -1,5 +1,20 @@
# icefall dockerfile
## Download from dockerhub
You can find pre-built docker image for icefall at the following address:
<https://hub.docker.com/r/k2fsa/icefall/tags>
Example usage:
```bash
docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
```
## Build from dockerfile
2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
ARG K2_VERSION="1.24.3.dev20230725+cuda11.3.torch1.12.1"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.3.torch1.12.1"
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
\
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -0,0 +1,72 @@
FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
ARG K2_VERSION="1.24.3.dev20230725+cuda11.6.torch1.13.0"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.6.torch1.13.0"
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
\
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
ENV LD_LIBRARY_PATH /opt/conda/lib/stubs:$LD_LIBRARY_PATH
WORKDIR /workspace/icefall

View File

@ -0,0 +1,86 @@
FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda10.2.torch1.9.0"
ARG TORCHAUDIO_VERSION="0.9.0"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
# see https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/
RUN rm /etc/apt/sources.list.d/cuda.list && \
rm /etc/apt/sources.list.d/nvidia-ml.list && \
apt-key del 7fa2af80
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
rm -v cuda-keyring_1.0-1_all.deb && \
apt-get update && \
rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip uninstall -y tqdm && \
pip install -U --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
\
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz \
tqdm>=4.63.0
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
ARG K2_VERSION="1.24.3.dev20230718+cuda11.7.torch2.0.0"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.7.torch2.0.0"
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
\
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -86,7 +86,13 @@ rst_epilog = """
.. _git-lfs: https://git-lfs.com/
.. _ncnn: https://github.com/tencent/ncnn
.. _LibriSpeech: https://www.openslr.org/12
.. _Gigaspeech: https://github.com/SpeechColab/GigaSpeech
.. _musan: http://www.openslr.org/17/
.. _ONNX: https://github.com/onnx/onnx
.. _onnxruntime: https://github.com/microsoft/onnxruntime
.. _torch: https://github.com/pytorch/pytorch
.. _torchaudio: https://github.com/pytorch/audio
.. _k2: https://github.com/k2-fsa/k2
.. _lhotse: https://github.com/lhotse-speech/lhotse
.. _yesno: https://www.openslr.org/1/
"""

View File

@ -0,0 +1,184 @@
.. _LODR:
LODR for RNN Transducer
=======================
As a type of E2E model, neural transducers are usually considered as having an internal
language model, which learns the language level information on the training corpus.
In real-life scenario, there is often a mismatch between the training corpus and the target corpus space.
This mismatch can be a problem when decoding for neural transducer models with language models as its internal
language can act "against" the external LM. In this tutorial, we show how to use
`Low-order Density Ratio <https://arxiv.org/abs/2203.16776>`_ to alleviate this effect to further improve the performance
of langugae model integration.
.. note::
This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply LODR to other recipes.
If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.
.. note::
For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However,
you can change the testing set to any other domains (e.g `GigaSpeech`_) and prepare the language models
using that corpus.
First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here <https://arxiv.org/abs/2002.11268>`_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
are acoustically similar, DR derives the following formular for decoding with Bayes' theorem:
.. math::
\text{score}\left(y_u|\mathit{x},y\right) =
\log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
\lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right)
where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively.
Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to
shallow fusion is the subtraction of the source domain LM.
Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is
considered to be weak and can only capture low-level language information. Therefore, `LODR <https://arxiv.org/abs/2203.16776>`__ proposed to use
a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula
during decoding for transducer model:
.. math::
\text{score}\left(y_u|\mathit{x},y\right) =
\log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) +
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
\lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR,
the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
Now, we will show you how to use LODR in ``icefall``.
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
The testing scenario here is intra-domain (we decode the model trained on `LibriSpeech`_ on `LibriSpeech`_ testing sets).
As the initial step, let's download the pre-trained model.
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--exp-dir $exp_dir \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search
The following WERs are achieved on test-clean and test-other:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 3.11 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.93 best for test-other
Then, we download the external language model and bi-gram LM that are necessary for LODR.
Note that the bi-gram is estimated on the LibriSpeech 960 hours' text.
.. code-block:: bash
$ # download the external LM
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ popd
$
$ # download the bi-gram
$ git lfs install
$ git clone https://huggingface.co/marcoyang/librispeech_bigram
$ pushd data/lang_bpe_500
$ ln -s ../../librispeech_bigram/2gram.fst.txt .
$ popd
Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_lm_LODR``:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.42
$ LODR_scale=-0.24
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_LODR \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500 \
--tokens-ngram 2 \
--ngram-lm-scale $LODR_scale
There are two extra arguments that need to be given when doing LODR. ``--tokens-ngram`` specifies the order of n-gram. As we
are using a bi-gram, we set it to 2. ``--ngram-lm-scale`` is the scale of the bi-gram, it should be a negative number
as we are subtracting the bi-gram's score during decoding.
The decoding results obtained with the above command are shown below:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.61 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 6.74 best for test-other
Recall that the lowest WER we obtained in :ref:`shallow_fusion` with beam size of 4 is ``2.77/7.08``, LODR
indeed **further improves** the WER. We can do even better if we increase ``--beam-size``:
.. list-table:: WER of LODR with different beam sizes
:widths: 25 25 50
:header-rows: 1
* - Beam size
- test-clean
- test-other
* - 4
- 2.61
- 6.74
* - 8
- 2.45
- 6.38
* - 12
- 2.4
- 6.23

View File

@ -0,0 +1,33 @@
Decoding with language models
=============================
This section describes how to use external langugage models
during decoding to improve the WER of transducer models.
The following decoding methods with external langugage models are available:
.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean)
:widths: 25 50
:header-rows: 1
* - Decoding method
- beam=4
* - ``modified_beam_search``
- Beam search (i.e. really n-best decoding, the "beam" is the value of n), similar to the original RNN-T paper. Note, this method does not use language model.
* - ``modified_beam_search_lm_shallow_fusion``
- As ``modified_beam_search``, but interpolate RNN-T scores with language model scores, also known as shallow fusion
* - ``modified_beam_search_LODR``
- As ``modified_beam_search_lm_shallow_fusion``, but subtract score of a (BPE-symbol-level) bigram backoff language model used as an approximation to the internal language model of RNN-T.
* - ``modified_beam_search_lm_rescore``
- As ``modified_beam_search``, but rescore the n-best hypotheses with external language model (e.g. RNNLM) and re-rank them.
* - ``modified_beam_search_lm_rescore_LODR``
- As ``modified_beam_search_lm_rescore``, but also subtract the score of a (BPE-symbol-level) bigram backoff language model during re-ranking.
.. toctree::
:maxdepth: 2
shallow-fusion
LODR
rescoring

View File

@ -0,0 +1,252 @@
.. _rescoring:
LM rescoring for Transducer
=================================
LM rescoring is a commonly used approach to incorporate external LM information. Unlike shallow-fusion-based
methods (see :ref:`shallow_fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search.
Rescoring is usually more efficient than shallow fusion since less computation is performed on the external LM.
In this tutorial, we will show you how to use external LM to rescore the n-best hypotheses decoded from neural transducer models in
`icefall <https://github.com/k2-fsa/icefall>`__.
.. note::
This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply shallow fusion to other recipes.
If you encounter any problems, please open an issue `here <https://github.com/k2-fsa/icefall/issues>`_.
.. note::
For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain.
.. HINT::
We recommend you to use a GPU for decoding.
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
As the initial step, let's download the pre-trained model.
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
As usual, we first test the model's performance without external LM. This can be done via the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--exp-dir $exp_dir \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search
The following WERs are achieved on test-clean and test-other:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 3.11 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.93 best for test-other
Now, we will try to improve the above WER numbers via external LM rescoring. We will download
a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.
.. note::
This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus.
You may also train a RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
for training a RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.
.. code-block:: bash
$ # download the external LM
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ popd
With the RNNLM available, we can rescore the n-best hypotheses generated from `modified_beam_search`. Here,
`n` should be the number of beams, i.e ``--beam-size``. The command for LM rescoring is
as follows. Note that the ``--decoding-method`` is set to `modified_beam_search_lm_rescore` and ``--use-shallow-fusion``
is set to `False`.
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.43
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_rescore \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
--use-shallow-fusion 0 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.93 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.6 best for test-other
Great! We made some improvements! Increasing the size of the n-best hypotheses will further boost the performance,
see the following table:
.. list-table:: WERs of LM rescoring with different beam sizes
:widths: 25 25 25
:header-rows: 1
* - Beam size
- test-clean
- test-other
* - 4
- 2.93
- 7.6
* - 8
- 2.67
- 7.11
* - 12
- 2.59
- 6.86
In fact, we can also apply LODR (see :ref:`LODR`) when doing LM rescoring. To do so, we need to
download the bi-gram required by LODR:
.. code-block:: bash
$ # download the bi-gram
$ git lfs install
$ git clone https://huggingface.co/marcoyang/librispeech_bigram
$ pushd data/lang_bpe_500
$ ln -s ../../librispeech_bigram/2gram.arpa .
$ popd
Then we can performn LM rescoring + LODR by changing the decoding method to `modified_beam_search_lm_rescore_LODR`.
.. note::
This decoding method requires the dependency of `kenlm <https://github.com/kpu/kenlm>`_. You can install it
via this command: `pip install https://github.com/kpu/kenlm/archive/master.zip`.
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.43
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_rescore_LODR \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
--use-shallow-fusion 0 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500
You should see the following WERs after executing the commands above:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.9 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.57 best for test-other
It's slightly better than LM rescoring. If we further increase the beam size, we will see
further improvements from LM rescoring + LODR:
.. list-table:: WERs of LM rescoring + LODR with different beam sizes
:widths: 25 25 25
:header-rows: 1
* - Beam size
- test-clean
- test-other
* - 4
- 2.9
- 7.57
* - 8
- 2.63
- 7.04
* - 12
- 2.52
- 6.73
As mentioned earlier, LM rescoring is usually faster than shallow-fusion based methods.
Here, we benchmark the WERs and decoding speed of them:
.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean)
:widths: 25 25 25 25
:header-rows: 1
* - Decoding method
- beam=4
- beam=8
- beam=12
* - ``modified_beam_search``
- 3.11/7.93; 132s
- 3.1/7.95; 177s
- 3.1/7.96; 210s
* - ``modified_beam_search_lm_shallow_fusion``
- 2.77/7.08; 262s
- 2.62/6.65; 352s
- 2.58/6.65; 488s
* - ``modified_beam_search_LODR``
- 2.61/6.74; 400s
- 2.45/6.38; 610s
- 2.4/6.23; 870s
* - ``modified_beam_search_lm_rescore``
- 2.93/7.6; 156s
- 2.67/7.11; 203s
- 2.59/6.86; 255s
* - ``modified_beam_search_lm_rescore_LODR``
- 2.9/7.57; 160s
- 2.63/7.04; 203s
- 2.52/6.73; 263s
.. note::
Decoding is performed with a single 32G V100, we set ``--max-duration`` to 600.
Decoding time here is only for reference and it may vary.

View File

@ -0,0 +1,176 @@
.. _shallow_fusion:
Shallow fusion for Transducer
=================================
External language models (LM) are commonly used to improve WERs for E2E ASR models.
This tutorial shows you how to perform ``shallow fusion`` with an external LM
to improve the word-error-rate of a transducer model.
.. note::
This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply shallow fusion to other recipes.
If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
.. note::
For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain.
.. HINT::
We recommend you to use a GPU for decoding.
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
As the initial step, let's download the pre-trained model.
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--exp-dir $exp_dir \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search
The following WERs are achieved on test-clean and test-other:
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 3.11 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.93 best for test-other
These are already good numbers! But we can further improve it by using shallow fusion with external LM.
Training a language model usually takes a long time, we can download a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.
.. code-block:: bash
$ # download the external LM
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ popd
.. note::
This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus.
You may also train a RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
for training a RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.
To use shallow fusion for decoding, we can execute the following command:
.. code-block:: bash
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.29
$ ./pruned_transducer_stateless7_streaming/decode.py \
--epoch 99 \
--avg 1 \
--use-averaged-model False \
--beam-size 4 \
--exp-dir $exp_dir \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search_lm_shallow_fusion \
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
--use-shallow-fusion 1 \
--lm-type rnn \
--lm-exp-dir $lm_dir \
--lm-epoch 99 \
--lm-scale $lm_scale \
--lm-avg 1 \
--rnn-lm-embedding-dim 2048 \
--rnn-lm-hidden-dim 2048 \
--rnn-lm-num-layers 3 \
--lm-vocab-size 500
Note that we set ``--decoding-method modified_beam_search_lm_shallow_fusion`` and ``--use-shallow-fusion True``
to use shallow fusion. ``--lm-type`` specifies the type of neural LM we are going to use, you can either choose
between ``rnn`` or ``transformer``. The following three arguments are associated with the rnn:
- ``--rnn-lm-embedding-dim``
The embedding dimension of the RNN LM
- ``--rnn-lm-hidden-dim``
The hidden dimension of the RNN LM
- ``--rnn-lm-num-layers``
The number of RNN layers in the RNN LM.
The decoding result obtained with the above command are shown below.
.. code-block:: text
$ For test-clean, WER of different settings are:
$ beam_size_4 2.77 best for test-clean
$ For test-other, WER of different settings are:
$ beam_size_4 7.08 best for test-other
The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
A few parameters can be tuned to further boost the performance of shallow fusion:
- ``--lm-scale``
Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
- ``--beam-size``
The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.
Here, we also show how `--beam-size` effect the WER and decoding time:
.. list-table:: WERs and decoding time (on test-clean) of shallow fusion with different beam sizes
:widths: 25 25 25 25
:header-rows: 1
* - Beam size
- test-clean
- test-other
- Decoding time on test-clean (s)
* - 4
- 2.77
- 7.08
- 262
* - 8
- 2.62
- 6.65
- 352
* - 12
- 2.58
- 6.65
- 488
As we see, a larger beam size during shallow fusion improves the WER, but is also slower.

Binary file not shown.

After

Width:  |  Height:  |  Size: 356 KiB

View File

@ -0,0 +1,17 @@
.. _icefall_docker:
Docker
======
This section describes how to use pre-built docker images to run `icefall`_.
.. hint::
If you only have CPUs available, you can still use the pre-built docker
images.
.. toctree::
:maxdepth: 2
./intro.rst

View File

@ -0,0 +1,171 @@
Introduction
=============
We have pre-built docker images hosted at the following address:
`<https://hub.docker.com/repository/docker/k2fsa/icefall/general>`_
.. figure:: img/docker-hub.png
:width: 600
:align: center
You can find the ``Dockerfile`` at `<https://github.com/k2-fsa/icefall/tree/master/docker>`_.
We describe the following items in this section:
- How to view available tags
- How to download pre-built docker images
- How to run the `yesno`_ recipe within a docker container on ``CPU``
View available tags
===================
You can use the following command to view available tags:
.. code-block:: bash
curl -s 'https://registry.hub.docker.com/v2/repositories/k2fsa/icefall/tags/'|jq '."results"[]["name"]'
which will give you something like below:
.. code-block:: bash
"torch2.0.0-cuda11.7"
"torch1.12.1-cuda11.3"
"torch1.9.0-cuda10.2"
"torch1.13.0-cuda11.6"
.. hint::
Available tags will be updated when there are new releases of `torch`_.
Please select an appropriate combination of `torch`_ and CUDA.
Download a docker image
=======================
Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use
the following command to download it:
.. code-block:: bash
sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6
Run a docker image with GPU
===========================
.. code-block:: bash
sudo docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
Run a docker image with CPU
===========================
.. code-block:: bash
sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
Run yesno within a docker container
===================================
After starting the container, the following interface is presented:
.. code-block:: bash
root@60c947eac59c:/workspace/icefall#
It shows the current user is ``root`` and the current working directory
is ``/workspace/icefall``.
Update the code
---------------
Please first run:
.. code-block:: bash
root@60c947eac59c:/workspace/icefall# git pull
so that your local copy contains the latest code.
Data preparation
----------------
Now we can use
.. code-block:: bash
root@60c947eac59c:/workspace/icefall# cd egs/yesno/ASR/
to switch to the ``yesno`` recipe and run
.. code-block:: bash
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./prepare.sh
.. hint::
If you are running without GPU, it may report the following error:
.. code-block:: bash
File "/opt/conda/lib/python3.9/site-packages/k2/__init__.py", line 23, in <module>
from _k2 import DeterminizeWeightPushingType
ImportError: libcuda.so.1: cannot open shared object file: No such file or directory
We can use the following command to fix it:
.. code-block:: bash
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ln -s /opt/conda/lib/stubs/libcuda.so /opt/conda/lib/stubs/libcuda.so.1
The logs of running ``./prepare.sh`` are listed below:
.. literalinclude:: ./log/log-preparation.txt
Training
--------
After preparing the data, we can start training with the following command
.. code-block:: bash
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/train.py
All of the training logs are given below:
.. hint::
It is running on CPU and it takes only 16 seconds for this run.
.. literalinclude:: ./log/log-train-2023-08-01-01-55-27
Decoding
--------
After training, we can decode the trained model with
.. code-block:: bash
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/decode.py
The decoding logs are given below:
.. code-block:: bash
2023-08-01 02:06:22,400 INFO [decode.py:263] Decoding started
2023-08-01 02:06:22,400 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d663.clean', 'torch-version': '1.13.0', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.9', 'icefall-git-branch': 'master', 'icefall-git-sha1': '375520d-clean', 'icefall-git-date': 'Fri Jul 28 07:43:08 2023', 'icefall-path': '/workspace/icefall', 'k2-path': '/opt/conda/lib/python3.9/site-packages/k2/__init__.py', 'lhotse-path': '/opt/conda/lib/python3.9/site-packages/lhotse/__init__.py', 'hostname': '60c947eac59c', 'IP address': '172.17.0.2'}}
2023-08-01 02:06:22,401 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-08-01 02:06:22,403 INFO [decode.py:273] device: cpu
2023-08-01 02:06:22,406 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-08-01 02:06:22,424 INFO [asr_datamodule.py:218] About to get test cuts
2023-08-01 02:06:22,425 INFO [asr_datamodule.py:252] About to get test cuts
2023-08-01 02:06:22,504 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.
2023-08-01 02:06:22,687 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2023-08-01 02:06:22,688 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2023-08-01 02:06:22,690 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2023-08-01 02:06:22,690 INFO [decode.py:316] Done!
Congratulations! You have finished successfully running `icefall`_ within a docker container.

View File

@ -21,9 +21,11 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
:caption: Contents:
installation/index
docker/index
faqs
model-export/index
.. toctree::
:maxdepth: 3
@ -34,3 +36,8 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
contributing/index
huggingface/index
.. toctree::
:maxdepth: 2
decoding-with-langugage-models/index

View File

@ -3,40 +3,28 @@
Installation
============
.. hint::
We also provide :ref:`icefall_docker` support, which has already setup
the environment for you.
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
`lhotse <https://github.com/lhotse-speech/lhotse>`_.
.. hint::
We have a colab notebook guiding you step by step to setup the environment.
|yesno colab notebook|
.. |yesno colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing
`icefall`_ depends on `k2`_ and `lhotse`_.
We recommend that you use the following steps to install the dependencies.
- (0) Install CUDA toolkit and cuDNN
- (1) Install PyTorch and torchaudio
- (2) Install k2
- (3) Install lhotse
.. caution::
99% users who have issues about the installation are using conda.
.. caution::
99% users who have issues about the installation are using conda.
.. caution::
99% users who have issues about the installation are using conda.
.. hint::
We suggest that you use ``pip install`` to install PyTorch.
You can use the following command to create a virutal environment in Python:
.. code-block:: bash
python3 -m venv ./my_env
source ./my_env/bin/activate
- (1) Install `torch`_ and `torchaudio`_
- (2) Install `k2`_
- (3) Install `lhotse`_
.. caution::
@ -50,27 +38,20 @@ Please refer to
to install CUDA and cuDNN.
(1) Install PyTorch and torchaudio
----------------------------------
(1) Install torch and torchaudio
--------------------------------
Please refer `<https://pytorch.org/>`_ to install PyTorch
and torchaudio.
.. hint::
You can also go to `<https://download.pytorch.org/whl/torch_stable.html>`_
to download pre-compiled wheels and install them.
Please refer `<https://pytorch.org/>`_ to install `torch`_ and `torchaudio`_.
.. caution::
Please install torch and torchaudio at the same time.
(2) Install k2
--------------
Please refer to `<https://k2-fsa.github.io/k2/installation/index.html>`_
to install ``k2``.
to install `k2`_.
.. caution::
@ -78,21 +59,18 @@ to install ``k2``.
.. note::
We suggest that you install k2 from source by following
`<https://k2-fsa.github.io/k2/installation/from_source.html>`_
or
`<https://k2-fsa.github.io/k2/installation/for_developers.html>`_.
We suggest that you install k2 from pre-compiled wheels by following
`<https://k2-fsa.github.io/k2/installation/from_wheels.html>`_
.. hint::
Please always install the latest version of k2.
Please always install the latest version of `k2`_.
(3) Install lhotse
------------------
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
to install ``lhotse``.
to install `lhotse`_.
.. hint::
@ -100,17 +78,16 @@ to install ``lhotse``.
pip install git+https://github.com/lhotse-speech/lhotse
to install the latest version of lhotse.
to install the latest version of `lhotse`_.
(4) Download icefall
--------------------
``icefall`` is a collection of Python scripts; what you need is to download it
`icefall`_ is a collection of Python scripts; what you need is to download it
and set the environment variable ``PYTHONPATH`` to point to it.
Assume you want to place ``icefall`` in the folder ``/tmp``. The
following commands show you how to setup ``icefall``:
Assume you want to place `icefall`_ in the folder ``/tmp``. The
following commands show you how to setup `icefall`_:
.. code-block:: bash
@ -122,285 +99,334 @@ following commands show you how to setup ``icefall``:
.. HINT::
You can put several versions of ``icefall`` in the same virtual environment.
To switch among different versions of ``icefall``, just set ``PYTHONPATH``
You can put several versions of `icefall`_ in the same virtual environment.
To switch among different versions of `icefall`_, just set ``PYTHONPATH``
to point to the version you want.
Installation example
--------------------
The following shows an example about setting up the environment.
(1) Create a virtual environment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
$ virtualenv -p python3.8 test-icefall
kuangfangjun:~$ virtualenv -p python3.8 test-icefall
created virtual environment CPython3.8.0.final.0-64 in 9422ms
creator CPython3Posix(dest=/star-fj/fangjun/test-icefall, clear=False, no_vcs_ignore=False, global=False)
seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/star-fj/fangjun/.local/share/virtualenv)
added seed packages: pip==22.3.1, setuptools==65.6.3, wheel==0.38.4
activators BashActivator,CShellActivator,FishActivator,NushellActivator,PowerShellActivator,PythonActivator
created virtual environment CPython3.8.6.final.0-64 in 1540ms
creator CPython3Posix(dest=/ceph-fj/fangjun/test-icefall, clear=False, no_vcs_ignore=False, global=False)
seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/fangjun/.local/share/v
irtualenv)
added seed packages: pip==21.1.3, setuptools==57.4.0, wheel==0.36.2
activators BashActivator,CShellActivator,FishActivator,PowerShellActivator,PythonActivator,XonshActivator
kuangfangjun:~$ source test-icefall/bin/activate
(test-icefall) kuangfangjun:~$
(2) Activate your virtual environment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(2) Install CUDA toolkit and cuDNN
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
You need to determine the version of CUDA toolkit to install.
.. code-block:: bash
$ source test-icefall/bin/activate
(test-icefall) kuangfangjun:~$ nvidia-smi | head -n 4
(3) Install k2
Wed Jul 26 21:57:49 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03 Driver Version: 510.47.03 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
You can choose any CUDA version that is ``not`` greater than the version printed by ``nvidia-smi``.
In our case, we can choose any version ``<= 11.6``.
We will use ``CUDA 11.6`` in this example. Please follow
`<https://k2-fsa.github.io/k2/installation/cuda-cudnn.html#cuda-11-6>`_
to install CUDA toolkit and cuDNN if you have not done that before.
After installing CUDA toolkit, you can use the following command to verify it:
.. code-block:: bash
(test-icefall) kuangfangjun:~$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
(3) Install torch and torchaudio
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Since we have selected CUDA toolkit ``11.6``, we have to install a version of `torch`_
that is compiled against CUDA ``11.6``. We select ``torch 1.13.0+cu116`` in this
example.
After selecting the version of `torch`_ to install, we need to also install
a compatible version of `torchaudio`_, which is ``0.13.0+cu116`` in our case.
Please refer to `<https://pytorch.org/audio/stable/installation.html#compatibility-matrix>`_
to select an appropriate version of `torchaudio`_ to install if you use a different
version of `torch`_.
.. code-block:: bash
(test-icefall) kuangfangjun:~$ pip install torch==1.13.0+cu116 torchaudio==0.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.13.0+cu116
Downloading https://download.pytorch.org/whl/cu116/torch-1.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl (1983.0 MB)
________________________________________ 2.0/2.0 GB 764.4 kB/s eta 0:00:00
Collecting torchaudio==0.13.0+cu116
Downloading https://download.pytorch.org/whl/cu116/torchaudio-0.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl (4.2 MB)
________________________________________ 4.2/4.2 MB 1.3 MB/s eta 0:00:00
Requirement already satisfied: typing-extensions in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from torch==1.13.0+cu116) (4.7.1)
Installing collected packages: torch, torchaudio
Successfully installed torch-1.13.0+cu116 torchaudio-0.13.0+cu116
Verify that `torch`_ and `torchaudio`_ are successfully installed:
.. code-block:: bash
(test-icefall) kuangfangjun:~$ python3 -c "import torch; print(torch.__version__)"
1.13.0+cu116
(test-icefall) kuangfangjun:~$ python3 -c "import torchaudio; print(torchaudio.__version__)"
0.13.0+cu116
(4) Install k2
~~~~~~~~~~~~~~
We will install `k2`_ from pre-compiled wheels by following
`<https://k2-fsa.github.io/k2/installation/from_wheels.html>`_
.. code-block:: bash
$ pip install k2==1.4.dev20210822+cpu.torch1.9.0 -f https://k2-fsa.org/nightly/index.html
(test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html
Looking in links: https://k2-fsa.org/nightly/index.html
Collecting k2==1.4.dev20210822+cpu.torch1.9.0
Downloading https://k2-fsa.org/nightly/whl/k2-1.4.dev20210822%2Bcpu.torch1.9.0-cp38-cp38-linux_x86_64.whl (1.6 MB)
|________________________________| 1.6 MB 185 kB/s
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in links: https://k2-fsa.github.io/k2/cuda.html
Collecting k2==1.24.3.dev20230725+cuda11.6.torch1.13.0
Downloading https://huggingface.co/csukuangfj/k2/resolve/main/ubuntu-cuda/k2-1.24.3.dev20230725%2Bcuda11.6.torch1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (104.3 MB)
________________________________________ 104.3/104.3 MB 5.1 MB/s eta 0:00:00
Requirement already satisfied: torch==1.13.0 in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from k2==1.24.3.dev20230725+cuda11.6.torch1.13.0) (1.13.0+cu116)
Collecting graphviz
Downloading graphviz-0.17-py3-none-any.whl (18 kB)
Collecting torch==1.9.0
Using cached torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl (831.4 MB)
Collecting typing-extensions
Using cached typing_extensions-3.10.0.0-py3-none-any.whl (26 kB)
Installing collected packages: typing-extensions, torch, graphviz, k2
Successfully installed graphviz-0.17 k2-1.4.dev20210822+cpu.torch1.9.0 torch-1.9.0 typing-extensions-3.10.0.0
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/de/5e/fcbb22c68208d39edff467809d06c9d81d7d27426460ebc598e55130c1aa/graphviz-0.20.1-py3-none-any.whl (47 kB)
Requirement already satisfied: typing-extensions in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from torch==1.13.0->k2==1.24.3.dev20230725+cuda11.6.torch1.13.0) (4.7.1)
Installing collected packages: graphviz, k2
Successfully installed graphviz-0.20.1 k2-1.24.3.dev20230725+cuda11.6.torch1.13.0
.. WARNING::
.. hint::
We choose to install a CPU version of k2 for testing. You would probably want to install
a CUDA version of k2.
Please refer to `<https://k2-fsa.github.io/k2/cuda.html>`_ for the available
pre-compiled wheels about `k2`_.
Verify that `k2`_ has been installed successfully:
(4) Install lhotse
.. code-block:: bash
(test-icefall) kuangfangjun:~$ python3 -m k2.version
Collecting environment information...
k2 version: 1.24.3
Build type: Release
Git SHA1: 4c05309499a08454997adf500b56dcc629e35ae5
Git date: Tue Jul 25 16:23:36 2023
Cuda used to build k2: 11.6
cuDNN used to build k2: 8.3.2
Python version used to build k2: 3.8
OS used to build k2: CentOS Linux release 7.9.2009 (Core)
CMake version: 3.27.0
GCC version: 9.3.1
CMAKE_CUDA_FLAGS: -Wno-deprecated-gpu-targets -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_80,code=sm_80 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_86,code=sm_86 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_86,code=compute_86 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-strict-overflow --compiler-options -Wno-unknown-pragmas
CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable -Wno-strict-overflow
PyTorch version used to build k2: 1.13.0+cu116
PyTorch is using Cuda: 11.6
NVTX enabled: True
With CUDA: True
Disable debug: True
Sync kernels : False
Disable checks: False
Max cpu memory allocate: 214748364800 bytes (or 200.0 GB)
k2 abort: False
__file__: /star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/version/version.py
_k2.__file__: /star-fj/fangjun/test-icefall/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so
(5) Install lhotse
~~~~~~~~~~~~~~~~~~
.. code-block::
.. code-block:: bash
$ pip install git+https://github.com/lhotse-speech/lhotse
(test-icefall) kuangfangjun:~$ pip install git+https://github.com/lhotse-speech/lhotse
Collecting git+https://github.com/lhotse-speech/lhotse
Cloning https://github.com/lhotse-speech/lhotse to /tmp/pip-req-build-7b1b76ge
Running command git clone -q https://github.com/lhotse-speech/lhotse /tmp/pip-req-build-7b1b76ge
Collecting audioread>=2.1.9
Using cached audioread-2.1.9-py3-none-any.whl
Collecting SoundFile>=0.10
Using cached SoundFile-0.10.3.post1-py2.py3-none-any.whl (21 kB)
Collecting click>=7.1.1
Using cached click-8.0.1-py3-none-any.whl (97 kB)
Cloning https://github.com/lhotse-speech/lhotse to /tmp/pip-req-build-vq12fd5i
Running command git clone --filter=blob:none --quiet https://github.com/lhotse-speech/lhotse /tmp/pip-req-build-vq12fd5i
Resolved https://github.com/lhotse-speech/lhotse to commit 7640d663469b22cd0b36f3246ee9b849cd25e3b7
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Collecting cytoolz>=0.10.1
Using cached cytoolz-0.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)
Collecting dataclasses
Using cached dataclasses-0.6-py3-none-any.whl (14 kB)
Collecting h5py>=2.10.0
Downloading h5py-3.4.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.5 MB)
|________________________________| 4.5 MB 684 kB/s
Collecting intervaltree>=3.1.0
Using cached intervaltree-3.1.0-py2.py3-none-any.whl
Collecting lilcom>=1.1.0
Using cached lilcom-1.1.1-cp38-cp38-linux_x86_64.whl
Collecting numpy>=1.18.1
Using cached numpy-1.21.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.8 MB)
Collecting packaging
Using cached packaging-21.0-py3-none-any.whl (40 kB)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1e/3b/a7828d575aa17fb7acaf1ced49a3655aa36dad7e16eb7e6a2e4df0dda76f/cytoolz-0.12.2-cp38-cp38-
manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
________________________________________ 2.0/2.0 MB 33.2 MB/s eta 0:00:00
Collecting pyyaml>=5.3.1
Using cached PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl (662 kB)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c8/6b/6600ac24725c7388255b2f5add93f91e58a5d7efaf4af244fdbcc11a541b/PyYAML-6.0.1-cp38-cp38-ma
nylinux_2_17_x86_64.manylinux2014_x86_64.whl (736 kB)
________________________________________ 736.6/736.6 kB 38.6 MB/s eta 0:00:00
Collecting dataclasses
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/26/2f/1095cdc2868052dd1e64520f7c0d5c8c550ad297e944e641dbf1ffbb9a5d/dataclasses-0.6-py3-none-
any.whl (14 kB)
Requirement already satisfied: torchaudio in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (0.13.0+cu116)
Collecting lilcom>=1.1.0
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a8/65/df0a69c52bd085ca1ad4e5c4c1a5c680e25f9477d8e49316c4ff1e5084a4/lilcom-1.7-cp38-cp38-many
linux_2_17_x86_64.manylinux2014_x86_64.whl (87 kB)
________________________________________ 87.1/87.1 kB 8.7 MB/s eta 0:00:00
Collecting tqdm
Downloading tqdm-4.62.1-py2.py3-none-any.whl (76 kB)
|________________________________| 76 kB 2.7 MB/s
Collecting torchaudio==0.9.0
Downloading torchaudio-0.9.0-cp38-cp38-manylinux1_x86_64.whl (1.9 MB)
|________________________________| 1.9 MB 73.1 MB/s
Requirement already satisfied: torch==1.9.0 in ./test-icefall/lib/python3.8/site-packages (from torchaudio==0.9.0->lhotse===0.8.0.dev
-2a1410b-clean) (1.9.0)
Requirement already satisfied: typing-extensions in ./test-icefall/lib/python3.8/site-packages (from torch==1.9.0->torchaudio==0.9.0-
>lhotse===0.8.0.dev-2a1410b-clean) (3.10.0.0)
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/e6/02/a2cff6306177ae6bc73bc0665065de51dfb3b9db7373e122e2735faf0d97/tqdm-4.65.0-py3-none-any
.whl (77 kB)
Requirement already satisfied: numpy>=1.18.1 in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (1.24.4)
Collecting audioread>=2.1.9
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/5d/cb/82a002441902dccbe427406785db07af10182245ee639ea9f4d92907c923/audioread-3.0.0.tar.gz (
377 kB)
Preparing metadata (setup.py) ... done
Collecting tabulate>=0.8.1
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-
any.whl (35 kB)
Collecting click>=7.1.1
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1a/70/e63223f8116931d365993d4a6b7ef653a4d920b41d03de7c59499962821f/click-8.1.6-py3-none-any.
whl (97 kB)
________________________________________ 97.9/97.9 kB 8.4 MB/s eta 0:00:00
Collecting packaging
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none-
any.whl (48 kB)
Collecting intervaltree>=3.1.0
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/50/fb/396d568039d21344639db96d940d40eb62befe704ef849b27949ded5c3bb/intervaltree-3.1.0.tar.gz
(32 kB)
Preparing metadata (setup.py) ... done
Requirement already satisfied: torch in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (1.13.0+cu116)
Collecting SoundFile>=0.10
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ad/bd/0602167a213d9184fc688b1086dc6d374b7ae8c33eccf169f9b50ce6568c/soundfile-0.12.1-py2.py3-
none-manylinux_2_17_x86_64.whl (1.3 MB)
________________________________________ 1.3/1.3 MB 46.5 MB/s eta 0:00:00
Collecting toolz>=0.8.0
Using cached toolz-0.11.1-py3-none-any.whl (55 kB)
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/7f/5c/922a3508f5bda2892be3df86c74f9cf1e01217c2b1f8a0ac4841d903e3e9/toolz-0.12.0-py3-none-any.whl (55 kB)
Collecting sortedcontainers<3.0,>=2.0
Using cached sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
Collecting cffi>=1.0
Using cached cffi-1.14.6-cp38-cp38-manylinux1_x86_64.whl (411 kB)
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/b7/8b/06f30caa03b5b3ac006de4f93478dbd0239e2a16566d81a106c322dc4f79/cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (442 kB)
Requirement already satisfied: typing-extensions in ./test-icefall/lib/python3.8/site-packages (from torch->lhotse==1.16.0.dev0+git.7640d66.clean) (4.7.1)
Collecting pycparser
Using cached pycparser-2.20-py2.py3-none-any.whl (112 kB)
Collecting pyparsing>=2.0.2
Using cached pyparsing-2.4.7-py2.py3-none-any.whl (67 kB)
Building wheels for collected packages: lhotse
Building wheel for lhotse (setup.py) ... done
Created wheel for lhotse: filename=lhotse-0.8.0.dev_2a1410b_clean-py3-none-any.whl size=342242 sha256=f683444afa4dc0881133206b4646a
9d0f774224cc84000f55d0a67f6e4a37997
Stored in directory: /tmp/pip-ephem-wheel-cache-ftu0qysz/wheels/7f/7a/8e/a0bf241336e2e3cb573e1e21e5600952d49f5162454f2e612f
WARNING: Built wheel for lhotse is invalid: Metadata 1.2 mandates PEP 440 version, but '0.8.0.dev-2a1410b-clean' is not
Failed to build lhotse
Installing collected packages: pycparser, toolz, sortedcontainers, pyparsing, numpy, cffi, tqdm, torchaudio, SoundFile, pyyaml, packa
ging, lilcom, intervaltree, h5py, dataclasses, cytoolz, click, audioread, lhotse
Running setup.py install for lhotse ... done
DEPRECATION: lhotse was installed using the legacy 'setup.py install' method, because a wheel could not be built for it. A possible
replacement is to fix the wheel build issue reported above. You can find discussion regarding this at https://github.com/pypa/pip/is
sues/8368.
Successfully installed SoundFile-0.10.3.post1 audioread-2.1.9 cffi-1.14.6 click-8.0.1 cytoolz-0.11.0 dataclasses-0.6 h5py-3.4.0 inter
valtree-3.1.0 lhotse-0.8.0.dev-2a1410b-clean lilcom-1.1.1 numpy-1.21.2 packaging-21.0 pycparser-2.20 pyparsing-2.4.7 pyyaml-5.4.1 sor
tedcontainers-2.4.0 toolz-0.11.1 torchaudio-0.9.0 tqdm-4.62.1
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/62/d5/5f610ebe421e85889f2e55e33b7f9a6795bd982198517d912eb1c76e1a53/pycparser-2.21-py2.py3-none-any.whl (118 kB)
Building wheels for collected packages: lhotse, audioread, intervaltree
Building wheel for lhotse (pyproject.toml) ... done
Created wheel for lhotse: filename=lhotse-1.16.0.dev0+git.7640d66.clean-py3-none-any.whl size=687627 sha256=cbf0a4d2d0b639b33b91637a4175bc251d6a021a069644ecb1a9f2b3a83d072a
Stored in directory: /tmp/pip-ephem-wheel-cache-wwtk90_m/wheels/7f/7a/8e/a0bf241336e2e3cb573e1e21e5600952d49f5162454f2e612f
Building wheel for audioread (setup.py) ... done
Created wheel for audioread: filename=audioread-3.0.0-py3-none-any.whl size=23704 sha256=5e2d3537c96ce9cf0f645a654c671163707bf8cb8d9e358d0e2b0939a85ff4c2
Stored in directory: /star-fj/fangjun/.cache/pip/wheels/e2/c3/9c/f19ae5a03f8862d9f0776b0c0570f1fdd60a119d90954e3f39
Building wheel for intervaltree (setup.py) ... done
Created wheel for intervaltree: filename=intervaltree-3.1.0-py2.py3-none-any.whl size=26098 sha256=2604170976cfffe0d2f678cb1a6e5b525f561cd50babe53d631a186734fec9f9
Stored in directory: /star-fj/fangjun/.cache/pip/wheels/f3/ed/2b/c179ebfad4e15452d6baef59737f27beb9bfb442e0620f7271
Successfully built lhotse audioread intervaltree
Installing collected packages: sortedcontainers, dataclasses, tqdm, toolz, tabulate, pyyaml, pycparser, packaging, lilcom, intervaltree, click, audioread, cytoolz, cffi, SoundFile, lhotse
Successfully installed SoundFile-0.12.1 audioread-3.0.0 cffi-1.15.1 click-8.1.6 cytoolz-0.12.2 dataclasses-0.6 intervaltree-3.1.0 lhotse-1.16.0.dev0+git.7640d66.clean lilcom-1.7 packaging-23.1 pycparser-2.21 pyyaml-6.0.1 sortedcontainers-2.4.0 tabulate-0.9.0 toolz-0.12.0 tqdm-4.65.0
(5) Download icefall
Verify that `lhotse`_ has been installed successfully:
.. code-block:: bash
(test-icefall) kuangfangjun:~$ python3 -c "import lhotse; print(lhotse.__version__)"
1.16.0.dev+git.7640d66.clean
(6) Download icefall
~~~~~~~~~~~~~~~~~~~~
.. code-block::
.. code-block:: bash
$ cd /tmp
$ git clone https://github.com/k2-fsa/icefall
(test-icefall) kuangfangjun:~$ cd /tmp/
(test-icefall) kuangfangjun:tmp$ git clone https://github.com/k2-fsa/icefall
Cloning into 'icefall'...
remote: Enumerating objects: 500, done.
remote: Counting objects: 100% (500/500), done.
remote: Compressing objects: 100% (308/308), done.
remote: Total 500 (delta 263), reused 307 (delta 102), pack-reused 0
Receiving objects: 100% (500/500), 172.49 KiB | 385.00 KiB/s, done.
Resolving deltas: 100% (263/263), done.
remote: Enumerating objects: 12942, done.
remote: Counting objects: 100% (67/67), done.
remote: Compressing objects: 100% (56/56), done.
remote: Total 12942 (delta 17), reused 35 (delta 6), pack-reused 12875
Receiving objects: 100% (12942/12942), 14.77 MiB | 9.29 MiB/s, done.
Resolving deltas: 100% (8835/8835), done.
$ cd icefall
$ pip install -r requirements.txt
Collecting kaldilm
Downloading kaldilm-1.8.tar.gz (48 kB)
|________________________________| 48 kB 574 kB/s
Collecting kaldialign
Using cached kaldialign-0.2-cp38-cp38-linux_x86_64.whl
Collecting sentencepiece>=0.1.96
Using cached sentencepiece-0.1.96-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
Collecting tensorboard
Using cached tensorboard-2.6.0-py3-none-any.whl (5.6 MB)
Requirement already satisfied: setuptools>=41.0.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r
requirements.txt (line 4)) (57.4.0)
Collecting absl-py>=0.4
Using cached absl_py-0.13.0-py3-none-any.whl (132 kB)
Collecting google-auth-oauthlib<0.5,>=0.4.1
Using cached google_auth_oauthlib-0.4.5-py2.py3-none-any.whl (18 kB)
Collecting grpcio>=1.24.3
Using cached grpcio-1.39.0-cp38-cp38-manylinux2014_x86_64.whl (4.3 MB)
Requirement already satisfied: wheel>=0.26 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r require
ments.txt (line 4)) (0.36.2)
Requirement already satisfied: numpy>=1.12.0 in /ceph-fj/fangjun/test-icefall/lib/python3.8/site-packages (from tensorboard->-r requi
rements.txt (line 4)) (1.21.2)
Collecting protobuf>=3.6.0
Using cached protobuf-3.17.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)
Collecting werkzeug>=0.11.15
Using cached Werkzeug-2.0.1-py3-none-any.whl (288 kB)
Collecting tensorboard-data-server<0.7.0,>=0.6.0
Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB)
Collecting google-auth<2,>=1.6.3
Downloading google_auth-1.35.0-py2.py3-none-any.whl (152 kB)
|________________________________| 152 kB 1.4 MB/s
Collecting requests<3,>=2.21.0
Using cached requests-2.26.0-py2.py3-none-any.whl (62 kB)
Collecting tensorboard-plugin-wit>=1.6.0
Using cached tensorboard_plugin_wit-1.8.0-py3-none-any.whl (781 kB)
Collecting markdown>=2.6.8
Using cached Markdown-3.3.4-py3-none-any.whl (97 kB)
Collecting six
Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting cachetools<5.0,>=2.0.0
Using cached cachetools-4.2.2-py3-none-any.whl (11 kB)
Collecting rsa<5,>=3.1.4
Using cached rsa-4.7.2-py3-none-any.whl (34 kB)
Collecting pyasn1-modules>=0.2.1
Using cached pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 kB)
Collecting requests-oauthlib>=0.7.0
Using cached requests_oauthlib-1.3.0-py2.py3-none-any.whl (23 kB)
Collecting pyasn1<0.5.0,>=0.4.6
Using cached pyasn1-0.4.8-py2.py3-none-any.whl (77 kB)
Collecting urllib3<1.27,>=1.21.1
Using cached urllib3-1.26.6-py2.py3-none-any.whl (138 kB)
Collecting certifi>=2017.4.17
Using cached certifi-2021.5.30-py2.py3-none-any.whl (145 kB)
Collecting charset-normalizer~=2.0.0
Using cached charset_normalizer-2.0.4-py3-none-any.whl (36 kB)
Collecting idna<4,>=2.5
Using cached idna-3.2-py3-none-any.whl (59 kB)
Collecting oauthlib>=3.0.0
Using cached oauthlib-3.1.1-py2.py3-none-any.whl (146 kB)
Building wheels for collected packages: kaldilm
Building wheel for kaldilm (setup.py) ... done
Created wheel for kaldilm: filename=kaldilm-1.8-cp38-cp38-linux_x86_64.whl size=897233 sha256=eccb906cafcd45bf9a7e1a1718e4534254bfb
f4c0d0cbc66eee6c88d68a63862
Stored in directory: /root/fangjun/.cache/pip/wheels/85/7d/63/f2dd586369b8797cb36d213bf3a84a789eeb92db93d2e723c9
Successfully built kaldilm
Installing collected packages: urllib3, pyasn1, idna, charset-normalizer, certifi, six, rsa, requests, pyasn1-modules, oauthlib, cach
etools, requests-oauthlib, google-auth, werkzeug, tensorboard-plugin-wit, tensorboard-data-server, protobuf, markdown, grpcio, google
-auth-oauthlib, absl-py, tensorboard, sentencepiece, kaldilm, kaldialign
Successfully installed absl-py-0.13.0 cachetools-4.2.2 certifi-2021.5.30 charset-normalizer-2.0.4 google-auth-1.35.0 google-auth-oaut
hlib-0.4.5 grpcio-1.39.0 idna-3.2 kaldialign-0.2 kaldilm-1.8 markdown-3.3.4 oauthlib-3.1.1 protobuf-3.17.3 pyasn1-0.4.8 pyasn1-module
s-0.2.8 requests-2.26.0 requests-oauthlib-1.3.0 rsa-4.7.2 sentencepiece-0.1.96 six-1.16.0 tensorboard-2.6.0 tensorboard-data-server-0
.6.1 tensorboard-plugin-wit-1.8.0 urllib3-1.26.6 werkzeug-2.0.1
(test-icefall) kuangfangjun:tmp$ cd icefall/
(test-icefall) kuangfangjun:icefall$ pip install -r ./requirements.txt
Test Your Installation
----------------------
To test that your installation is successful, let us run
the `yesno recipe <https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR>`_
on CPU.
on ``CPU``.
Data preparation
~~~~~~~~~~~~~~~~
.. code-block:: bash
$ export PYTHONPATH=/tmp/icefall:$PYTHONPATH
$ cd /tmp/icefall
$ cd egs/yesno/ASR
$ ./prepare.sh
(test-icefall) kuangfangjun:icefall$ export PYTHONPATH=/tmp/icefall:$PYTHONPATH
(test-icefall) kuangfangjun:icefall$ cd /tmp/icefall
(test-icefall) kuangfangjun:icefall$ cd egs/yesno/ASR
(test-icefall) kuangfangjun:ASR$ ./prepare.sh
The log of running ``./prepare.sh`` is:
.. code-block::
2023-05-12 17:55:21 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
2023-05-12 17:55:21 (prepare.sh:30:main) Stage 0: Download data
/tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|_______________________________________________________________| 4.70M/4.70M [06:54<00:00, 11.4kB/s]
2023-05-12 18:02:19 (prepare.sh:39:main) Stage 1: Prepare yesno manifest
2023-05-12 18:02:21 (prepare.sh:45:main) Stage 2: Compute fbank for yesno
2023-05-12 18:02:23,199 INFO [compute_fbank_yesno.py:65] Processing train
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:00<00:00, 212.60it/s]
2023-05-12 18:02:23,640 INFO [compute_fbank_yesno.py:65] Processing test
Extracting and storing features: 100%|_______________________________________________________________| 30/30 [00:00<00:00, 304.53it/s]
2023-05-12 18:02:24 (prepare.sh:51:main) Stage 3: Prepare lang
2023-05-12 18:02:26 (prepare.sh:66:main) Stage 4: Prepare G
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79
[I] Reading \data\ section.
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140
[I] Reading \1-grams: section.
2023-05-12 18:02:26 (prepare.sh:92:main) Stage 5: Compile HLG
2023-05-12 18:02:28,581 INFO [compile_hlg.py:124] Processing data/lang_phone
2023-05-12 18:02:28,582 INFO [lexicon.py:171] Converting L.pt to Linv.pt
2023-05-12 18:02:28,609 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
2023-05-12 18:02:28,610 INFO [compile_hlg.py:52] Loading G.fst.txt
2023-05-12 18:02:28,611 INFO [compile_hlg.py:62] Intersecting L and G
2023-05-12 18:02:28,613 INFO [compile_hlg.py:64] LG shape: (4, None)
2023-05-12 18:02:28,613 INFO [compile_hlg.py:66] Connecting LG
2023-05-12 18:02:28,614 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
2023-05-12 18:02:28,614 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
2023-05-12 18:02:28,614 INFO [compile_hlg.py:71] Determinizing LG
2023-05-12 18:02:28,615 INFO [compile_hlg.py:74] <class '_k2.ragged.RaggedTensor'>
2023-05-12 18:02:28,615 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
2023-05-12 18:02:28,615 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
2023-05-12 18:02:28,616 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None)
2023-05-12 18:02:28,617 INFO [compile_hlg.py:96] Arc sorting LG
2023-05-12 18:02:28,617 INFO [compile_hlg.py:99] Composing H and LG
2023-05-12 18:02:28,619 INFO [compile_hlg.py:106] Connecting LG
2023-05-12 18:02:28,619 INFO [compile_hlg.py:109] Arc sorting LG
2023-05-12 18:02:28,619 INFO [compile_hlg.py:111] HLG.shape: (8, None)
2023-05-12 18:02:28,619 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone
2023-07-27 12:41:39 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
2023-07-27 12:41:39 (prepare.sh:30:main) Stage 0: Download data
/tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|___________________________________________________| 4.70M/4.70M [00:00<00:00, 11.1MB/s]
2023-07-27 12:41:46 (prepare.sh:39:main) Stage 1: Prepare yesno manifest
2023-07-27 12:41:50 (prepare.sh:45:main) Stage 2: Compute fbank for yesno
2023-07-27 12:41:55,718 INFO [compute_fbank_yesno.py:65] Processing train
Extracting and storing features: 100%|_______________________________________________________________________________| 90/90 [00:01<00:00, 87.82it/s]
2023-07-27 12:41:56,778 INFO [compute_fbank_yesno.py:65] Processing test
Extracting and storing features: 100%|______________________________________________________________________________| 30/30 [00:00<00:00, 256.92it/s]
2023-07-27 12:41:57 (prepare.sh:51:main) Stage 3: Prepare lang
2023-07-27 12:42:02 (prepare.sh:66:main) Stage 4: Prepare G
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79
[I] Reading \data\ section.
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140
[I] Reading \1-grams: section.
2023-07-27 12:42:02 (prepare.sh:92:main) Stage 5: Compile HLG
2023-07-27 12:42:07,275 INFO [compile_hlg.py:124] Processing data/lang_phone
2023-07-27 12:42:07,276 INFO [lexicon.py:171] Converting L.pt to Linv.pt
2023-07-27 12:42:07,309 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
2023-07-27 12:42:07,310 INFO [compile_hlg.py:52] Loading G.fst.txt
2023-07-27 12:42:07,314 INFO [compile_hlg.py:62] Intersecting L and G
2023-07-27 12:42:07,323 INFO [compile_hlg.py:64] LG shape: (4, None)
2023-07-27 12:42:07,323 INFO [compile_hlg.py:66] Connecting LG
2023-07-27 12:42:07,323 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
2023-07-27 12:42:07,323 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
2023-07-27 12:42:07,323 INFO [compile_hlg.py:71] Determinizing LG
2023-07-27 12:42:07,341 INFO [compile_hlg.py:74] <class '_k2.ragged.RaggedTensor'>
2023-07-27 12:42:07,341 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
2023-07-27 12:42:07,341 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
2023-07-27 12:42:07,354 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None)
2023-07-27 12:42:07,445 INFO [compile_hlg.py:96] Arc sorting LG
2023-07-27 12:42:07,445 INFO [compile_hlg.py:99] Composing H and LG
2023-07-27 12:42:07,446 INFO [compile_hlg.py:106] Connecting LG
2023-07-27 12:42:07,446 INFO [compile_hlg.py:109] Arc sorting LG
2023-07-27 12:42:07,447 INFO [compile_hlg.py:111] HLG.shape: (8, None)
2023-07-27 12:42:07,447 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone
Training
~~~~~~~~
@ -409,12 +435,13 @@ Now let us run the training part:
.. code-block::
$ export CUDA_VISIBLE_DEVICES=""
$ ./tdnn/train.py
(test-icefall) kuangfangjun:ASR$ export CUDA_VISIBLE_DEVICES=""
(test-icefall) kuangfangjun:ASR$ ./tdnn/train.py
.. CAUTION::
We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
We use ``export CUDA_VISIBLE_DEVICES=""`` so that `icefall`_ uses CPU
even if there are GPUs available.
.. hint::
@ -432,53 +459,52 @@ The training log is given below:
.. code-block::
2023-05-12 18:04:59,759 INFO [train.py:481] Training started
2023-05-12 18:04:59,759 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0,
'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10,
'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0,
'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2,
'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
'k2-path': 'tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
'lhotse-path': 'tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
2023-05-12 18:04:59,761 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-05-12 18:04:59,764 INFO [train.py:495] device: cpu
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:146] About to get train cuts
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:244] About to get train cuts
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:149] About to create train dataset
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:199] Using SingleCutSampler.
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:205] About to create train dataloader
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:218] About to get test cuts
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:252] About to get test cuts
2023-05-12 18:04:59,986 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:00,352 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:00,691 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames.
2023-05-12 18:05:00,996 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:01,217 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames.
2023-05-12 18:05:01,251 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt
2023-05-12 18:05:01,389 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:01,637 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:01,859 INFO [train.py:444] Epoch 1, validation loss=0.1629, over 18067.00 frames.
2023-05-12 18:05:02,094 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.0767, over 2695.00 frames. ], tot_loss[loss=0.118, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:02,350 INFO [train.py:444] Epoch 1, validation loss=0.06778, over 18067.00 frames.
2023-05-12 18:05:02,395 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt
2023-07-27 12:50:51,936 INFO [train.py:481] Training started
2023-07-27 12:50:51,936 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d66.clean', 'torch-version': '1.13.0+cu116', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '3fb0a43-clean', 'icefall-git-date': 'Thu Jul 27 12:36:05 2023', 'icefall-path': '/tmp/icefall', 'k2-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-1-1220091118-57c4d55446-sph26', 'IP address': '10.177.77.20'}}
2023-07-27 12:50:51,941 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-07-27 12:50:51,949 INFO [train.py:495] device: cpu
2023-07-27 12:50:51,965 INFO [asr_datamodule.py:146] About to get train cuts
2023-07-27 12:50:51,965 INFO [asr_datamodule.py:244] About to get train cuts
2023-07-27 12:50:51,967 INFO [asr_datamodule.py:149] About to create train dataset
2023-07-27 12:50:51,967 INFO [asr_datamodule.py:199] Using SingleCutSampler.
2023-07-27 12:50:51,967 INFO [asr_datamodule.py:205] About to create train dataloader
2023-07-27 12:50:51,968 INFO [asr_datamodule.py:218] About to get test cuts
2023-07-27 12:50:51,968 INFO [asr_datamodule.py:252] About to get test cuts
2023-07-27 12:50:52,565 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4
2023-07-27 12:50:53,681 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames.], batch size: 4
2023-07-27 12:50:54,167 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames.
2023-07-27 12:50:55,011 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5
2023-07-27 12:50:55,331 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames.
2023-07-27 12:50:55,368 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt
2023-07-27 12:50:55,633 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ],
batch size: 4
2023-07-27 12:50:56,242 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames.], batch size: 4
2023-07-27 12:50:56,522 INFO [train.py:444] Epoch 1, validation loss=0.1627, over 18067.00 frames.
2023-07-27 12:50:57,209 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.07055, over 2695.00 frames. ], tot_loss[loss=0.1175, over 34971.47 frames.], batch size: 5
2023-07-27 12:50:57,600 INFO [train.py:444] Epoch 1, validation loss=0.07091, over 18067.00 frames.
2023-07-27 12:50:57,640 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt
2023-07-27 12:50:57,847 INFO [train.py:422] Epoch 2, batch 0, loss[loss=0.07731, over 2436.00 frames. ], tot_loss[loss=0.07731, over 2436.00 frames.], batch size: 4
2023-07-27 12:50:58,427 INFO [train.py:422] Epoch 2, batch 10, loss[loss=0.04391, over 2828.00 frames. ], tot_loss[loss=0.05341, over 22192.90 frames. ], batch size: 4
2023-07-27 12:50:58,884 INFO [train.py:444] Epoch 2, validation loss=0.04384, over 18067.00 frames.
2023-07-27 12:50:59,387 INFO [train.py:422] Epoch 2, batch 20, loss[loss=0.03458, over 2695.00 frames. ], tot_loss[loss=0.04616, over 34971.47 frames. ], batch size: 5
2023-07-27 12:50:59,707 INFO [train.py:444] Epoch 2, validation loss=0.03379, over 18067.00 frames.
2023-07-27 12:50:59,758 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-2.pt
... ...
... ...
2023-05-12 18:05:14,789 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01056, over 2436.00 frames. ], tot_loss[loss=0.01056, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:15,016 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009022, over 2828.00 frames. ], tot_loss[loss=0.009985, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:15,271 INFO [train.py:444] Epoch 13, validation loss=0.01088, over 18067.00 frames.
2023-05-12 18:05:15,497 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01174, over 2695.00 frames. ], tot_loss[loss=0.01077, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:15,747 INFO [train.py:444] Epoch 13, validation loss=0.01087, over 18067.00 frames.
2023-05-12 18:05:15,783 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt
2023-05-12 18:05:15,921 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01045, over 2436.00 frames. ], tot_loss[loss=0.01045, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:16,146 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008957, over 2828.00 frames. ], tot_loss[loss=0.009903, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:16,374 INFO [train.py:444] Epoch 14, validation loss=0.01092, over 18067.00 frames.
2023-05-12 18:05:16,598 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01065, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:16,824 INFO [train.py:444] Epoch 14, validation loss=0.01077, over 18067.00 frames.
2023-05-12 18:05:16,862 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt
2023-05-12 18:05:16,865 INFO [train.py:555] Done!
2023-07-27 12:51:23,433 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01054, over 2436.00 frames. ], tot_loss[loss=0.01054, over 2436.00 frames. ], batch size: 4
2023-07-27 12:51:23,980 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009014, over 2828.00 frames. ], tot_loss[loss=0.009974, over 22192.90 frames. ], batch size: 4
2023-07-27 12:51:24,489 INFO [train.py:444] Epoch 13, validation loss=0.01085, over 18067.00 frames.
2023-07-27 12:51:25,258 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01172, over 2695.00 frames. ], tot_loss[loss=0.01055, over 34971.47 frames. ], batch size: 5
2023-07-27 12:51:25,621 INFO [train.py:444] Epoch 13, validation loss=0.01074, over 18067.00 frames.
2023-07-27 12:51:25,699 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt
2023-07-27 12:51:25,866 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01044, over 2436.00 frames. ], tot_loss[loss=0.01044, over 2436.00 frames. ], batch size: 4
2023-07-27 12:51:26,844 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008942, over 2828.00 frames. ], tot_loss[loss=0.01, over 22192.90 frames. ], batch size: 4
2023-07-27 12:51:27,221 INFO [train.py:444] Epoch 14, validation loss=0.01082, over 18067.00 frames.
2023-07-27 12:51:27,970 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01054, over 34971.47 frames. ], batch size: 5
2023-07-27 12:51:28,247 INFO [train.py:444] Epoch 14, validation loss=0.01073, over 18067.00 frames.
2023-07-27 12:51:28,323 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt
2023-07-27 12:51:28,326 INFO [train.py:555] Done!
Decoding
~~~~~~~~
@ -487,42 +513,32 @@ Let us use the trained model to decode the test set:
.. code-block::
$ ./tdnn/decode.py
(test-icefall) kuangfangjun:ASR$ ./tdnn/decode.py
The decoding log is:
2023-07-27 12:55:12,840 INFO [decode.py:263] Decoding started
2023-07-27 12:55:12,840 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d66.clean', 'torch-version': '1.13.0+cu116', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '3fb0a43-clean', 'icefall-git-date': 'Thu Jul 27 12:36:05 2023', 'icefall-path': '/tmp/icefall', 'k2-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-1-1220091118-57c4d55446-sph26', 'IP address': '10.177.77.20'}}
2023-07-27 12:55:12,841 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-07-27 12:55:12,855 INFO [decode.py:273] device: cpu
2023-07-27 12:55:12,868 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-07-27 12:55:12,882 INFO [asr_datamodule.py:218] About to get test cuts
2023-07-27 12:55:12,883 INFO [asr_datamodule.py:252] About to get test cuts
2023-07-27 12:55:13,157 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
2023-07-27 12:55:13,701 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2023-07-27 12:55:13,702 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2023-07-27 12:55:13,704 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2023-07-27 12:55:13,704 INFO [decode.py:316] Done!
.. code-block::
2023-05-12 18:08:30,482 INFO [decode.py:263] Decoding started
2023-05-12 18:08:30,483 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23,
'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'),
'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True,
'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
'k2-path': '/tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
'lhotse-path': '/tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
2023-05-12 18:08:30,483 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-05-12 18:08:30,487 INFO [decode.py:273] device: cpu
2023-05-12 18:08:30,513 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:218] About to get test cuts
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:252] About to get test cuts
2023-05-12 18:08:30,675 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
2023-05-12 18:08:30,923 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2023-05-12 18:08:30,924 INFO [utils.py:558] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2023-05-12 18:08:30,925 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2023-05-12 18:08:30,925 INFO [decode.py:316] Done!
**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.
**Congratulations!** You have successfully setup the environment and have run the first recipe in `icefall`_.
Have fun with ``icefall``!
YouTube Video
-------------
We provide the following YouTube video showing how to install ``icefall``.
We provide the following YouTube video showing how to install `icefall`_.
It also shows how to debug various problems that you may encounter while
using ``icefall``.
using `icefall`_.
.. note::

View File

@ -1,7 +1,7 @@
Distillation with HuBERT
========================
This tutorial shows you how to perform knowledge distillation in `icefall`_
This tutorial shows you how to perform knowledge distillation in `icefall <https://github.com/k2-fsa/icefall>`_
with the `LibriSpeech`_ dataset. The distillation method
used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD).
Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation <https://arxiv.org/abs/2211.00508>`_
@ -13,7 +13,7 @@ for more details about MVQ-KD.
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_.
Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes
with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.
.. note::
@ -217,7 +217,7 @@ the following command.
--exp-dir $exp_dir \
--enable-distillation True
You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`_.
You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`__.
That's all! Feel free to experiment with your own setups and report your results.
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`_.
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`__.

View File

@ -8,10 +8,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. Note::
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial.
.. HINT::
@ -237,7 +237,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
.. NOTE::
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.
@ -529,13 +529,13 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`_
- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`__
- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`_
- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`__
- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`_
- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`__
- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`_
- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`__
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models

View File

@ -45,9 +45,9 @@ the input features.
We have three variants of Emformer models in ``icefall``.
- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`_.
- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`__.
- ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio,
ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model.
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`_.
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`__.
- ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that
it uses a simplified memory bank. See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless2>`_.

View File

@ -6,10 +6,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. Note::
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial.
.. HINT::
@ -264,7 +264,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
.. NOTE::
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.

View File

@ -6,7 +6,7 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
.. Note::
The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`__,
.. HINT::
@ -642,7 +642,7 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_
- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models

View File

@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests/aidatatang_200zh")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -109,7 +110,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
@ -119,4 +125,6 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
compute_fbank_aidatatang_200zh(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

View File

@ -77,7 +77,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for aidatatang_200zh"
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aidatatang_200zh.py
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
touch data/fbank/.aidatatang_200zh.done
fi
fi

View File

@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -109,7 +110,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
@ -119,4 +125,6 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
compute_fbank_aidatatang_200zh(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

View File

@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell(num_mel_bins: int = 80):
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -81,7 +81,8 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -104,7 +105,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
@ -114,4 +120,6 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell(num_mel_bins=args.num_mel_bins)
compute_fbank_aishell(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

View File

@ -114,7 +114,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell"
if [ ! -f data/fbank/.aishell.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell.py
./local/compute_fbank_aishell.py --perturb-speed True
touch data/fbank/.aishell.done
fi
fi

View File

@ -53,7 +53,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aidatatang_200zh"
if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aidatatang_200zh.py
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
touch data/fbank/.aidatatang_200zh_fbank.done
fi
fi

View File

@ -240,7 +240,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
default="pruned_transducer_stateless7/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved

View File

@ -243,7 +243,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
default="pruned_transducer_stateless7/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved

View File

@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell2(num_mel_bins: int = 80):
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -81,7 +81,8 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -104,6 +105,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
@ -114,4 +121,6 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
compute_fbank_aishell2(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

View File

@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell2"
if [ ! -f data/fbank/.aishell2.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell2.py
./local/compute_fbank_aishell2.py --perturb-speed True
touch data/fbank/.aishell2.done
fi
fi

View File

@ -32,7 +32,7 @@ import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell4(num_mel_bins: int = 80):
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -83,10 +83,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
@ -113,6 +115,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
@ -123,4 +131,6 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)
compute_fbank_aishell4(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

View File

@ -107,7 +107,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for aishell4"
if [ ! -f data/fbank/.aishell4.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py
./local/compute_fbank_aishell4.py --perturb-speed True
touch data/fbank/.aishell4.done
fi
fi

View File

@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_alimeeting(num_mel_bins: int = 80):
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -82,7 +82,8 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition:
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -114,6 +115,12 @@ def get_args():
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
@ -124,4 +131,6 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins)
compute_fbank_alimeeting(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

View File

@ -97,7 +97,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for alimeeting"
if [ ! -f data/fbank/.alimeeting.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py
./local/compute_fbank_alimeeting.py --perturb-speed True
touch data/fbank/.alimeeting.done
fi
fi

View File

@ -25,6 +25,7 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
from pathlib import Path
@ -39,6 +40,8 @@ from lhotse.features.kaldifeat import (
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@ -48,7 +51,7 @@ torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_ami():
def compute_fbank_ami(perturb_speed: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
@ -84,8 +87,12 @@ def compute_fbank_ami():
suffix="jsonl.gz",
)
def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
def _extract_feats(
cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool
) -> None:
if speed_perturb:
logging.info(f"Doing speed perturb")
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
_ = cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=storage_path,
@ -109,6 +116,7 @@ def compute_fbank_ami():
cuts_ihm,
output_dir / "feats_train_ihm",
src_dir / "cuts_train_ihm.jsonl.gz",
perturb_speed,
)
logging.info("Processing train split IHM + reverberated IHM")
@ -117,6 +125,7 @@ def compute_fbank_ami():
cuts_ihm_rvb,
output_dir / "feats_train_ihm_rvb",
src_dir / "cuts_train_ihm_rvb.jsonl.gz",
perturb_speed,
)
logging.info("Processing train split SDM")
@ -129,6 +138,7 @@ def compute_fbank_ami():
cuts_sdm,
output_dir / "feats_train_sdm",
src_dir / "cuts_train_sdm.jsonl.gz",
perturb_speed,
)
logging.info("Processing train split GSS")
@ -141,6 +151,7 @@ def compute_fbank_ami():
cuts_gss,
output_dir / "feats_train_gss",
src_dir / "cuts_train_gss.jsonl.gz",
perturb_speed,
)
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
@ -186,8 +197,21 @@ def compute_fbank_ami():
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_ami()
args = get_args()
compute_fbank_ami(perturb_speed=args.perturb_speed)

View File

@ -85,7 +85,7 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for alimeeting"
mkdir -p data/fbank
python local/compute_fbank_alimeeting.py
python local/compute_fbank_alimeeting.py --perturb-speed True
log "Combine features from train splits"
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
gzip -c > data/manifests/cuts_train_all.jsonl.gz

156
egs/ami/SURT/README.md Normal file
View File

@ -0,0 +1,156 @@
# Introduction
This is a multi-talker ASR recipe for the AMI and ICSI datasets. We train a Streaming
Unmixing and Recognition Transducer (SURT) model for the task.
Please refer to the `egs/libricss/SURT` recipe README for details about the task and the
model.
## Description of the recipe
### Pre-requisites
The recipes in this directory need the following packages to be installed:
- [meeteval](https://github.com/fgnt/meeteval)
- [einops](https://github.com/arogozhnikov/einops)
Additionally, we initialize the model with the pre-trained model from the LibriCSS recipe.
Please download this checkpoint (see below) or train the LibriCSS recipe first.
### Training
To train the model, run the following from within `egs/ami/SURT`:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python dprnn_zipformer/train.py \
--use-fp16 True \
--exp-dir dprnn_zipformer/exp/surt_base \
--world-size 4 \
--max-duration 500 \
--max-duration-valid 250 \
--max-cuts 200 \
--num-buckets 50 \
--num-epochs 30 \
--enable-spec-aug True \
--enable-musan False \
--ctc-loss-scale 0.2 \
--heat-loss-scale 0.2 \
--base-lr 0.004 \
--model-init-ckpt exp/libricss_base.pt \
--chunk-width-randomization True \
--num-mask-encoder-layers 4 \
--num-encoder-layers 2,2,2,2,2
```
The above is for SURT-base (~26M). For SURT-large (~38M), use:
```bash
--model-init-ckpt exp/libricss_large.pt \
--num-mask-encoder-layers 6 \
--num-encoder-layers 2,4,3,2,4 \
--model-init-ckpt exp/zipformer_large.pt \
```
**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.
### Adaptation
The training step above only trains on simulated mixtures. For best results, we also
adapt the final model on the AMI+ICSI train set. For this, run the following from within
`egs/ami/SURT`:
```bash
export CUDA_VISIBLE_DEVICES="0"
python dprnn_zipformer/train_adapt.py \
--use-fp16 True \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--world-size 4 \
--max-duration 500 \
--max-duration-valid 250 \
--max-cuts 200 \
--num-buckets 50 \
--num-epochs 8 \
--lr-epochs 2 \
--enable-spec-aug True \
--enable-musan False \
--ctc-loss-scale 0.2 \
--base-lr 0.0004 \
--model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
--chunk-width-randomization True \
--num-mask-encoder-layers 4 \
--num-encoder-layers 2,2,2,2,2
```
For SURT-large, use the following config:
```bash
--num-mask-encoder-layers 6 \
--num-encoder-layers 2,4,3,2,4 \
--model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
--num-epochs 15 \
--lr-epochs 4 \
```
### Decoding
To decode the model, run the following from within `egs/ami/SURT`:
#### Greedy search
```bash
export CUDA_VISIBLE_DEVICES="0"
python dprnn_zipformer/decode.py \
--epoch 20 --avg 1 --use-averaged-model False \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--max-duration 250 \
--decoding-method greedy_search
```
#### Beam search
```bash
python dprnn_zipformer/decode.py \
--epoch 20 --avg 1 --use-averaged-model False \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--max-duration 250 \
--decoding-method modified_beam_search \
--beam-size 4
```
## Results (using beam search)
**AMI**
| Model | IHM-Mix | SDM | MDM |
|------------|:-------:|:----:|:----:|
| SURT-base | 39.8 | 65.4 | 46.6 |
| + adapt | 37.4 | 46.9 | 43.7 |
| SURT-large | 36.8 | 62.5 | 44.4 |
| + adapt | **35.1** | **44.6** | **41.4** |
**ICSI**
| Model | IHM-Mix | SDM |
|------------|:-------:|:----:|
| SURT-base | 28.3 | 60.0 |
| + adapt | 26.3 | 33.9 |
| SURT-large | 27.8 | 59.7 |
| + adapt | **24.4** | **32.3** |
## Pre-trained models and logs
* LibriCSS pre-trained model (for initialization): [base](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_base) [large](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_large)
* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-ami-dprnn-zipformer>
* Training logs:
- surt_base: <https://tensorboard.dev/experiment/8awy98VZSWegLmH4l2JWSA/>
- surt_base_adapt: <https://tensorboard.dev/experiment/aGVgXVzYRDKbGUbPekcNjg/>
- surt_large: <https://tensorboard.dev/experiment/ZXMkez0VSYKbPLqRk4clOQ/>
- surt_large_adapt: <https://tensorboard.dev/experiment/WLKL1e7bTVyEjSonYSNYwg/>

View File

@ -0,0 +1,399 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutMix,
DynamicBucketingSampler,
K2SurtDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AmiAsrDataModule:
"""
DataModule for k2 SURT experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--max-duration-valid",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--max-cuts",
type=int,
default=100,
help="Maximum number of cuts in a single batch. You can "
"reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
sources: bool = False,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SurtDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
return_sources=sources,
strict=False,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
quadratic_duration=30.0,
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
logging.info("About to create dev dataset")
validate = K2SurtDataset(
input_strategy=OnTheFlyFeatures(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
return_sources=False,
strict=False,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration_valid,
quadratic_duration=30.0,
max_cuts=self.args.max_cuts,
shuffle=False,
)
logging.info("About to create dev dataloader")
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SurtDataset(
input_strategy=OnTheFlyFeatures(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
return_sources=False,
strict=False,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration_valid,
max_cuts=self.args.max_cuts,
shuffle=False,
)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return test_dl
@lru_cache()
def aimix_train_cuts(
self,
rvb_affix: str = "clean",
sources: bool = True,
) -> CutSet:
logging.info("About to get train cuts")
source_affix = "_sources" if sources else ""
cs = load_manifest_lazy(
self.args.manifest_dir / f"cuts_train_{rvb_affix}{source_affix}.jsonl.gz"
)
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
return cs
@lru_cache()
def train_cuts(
self,
) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_ami_icsi.jsonl.gz"
)
@lru_cache()
def ami_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
logging.info(f"About to get AMI {split} {type} cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"cuts_ami-{type}_{split}.jsonl.gz"
)
@lru_cache()
def icsi_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
logging.info(f"About to get ICSI {split} {type} cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"cuts_icsi-{type}_{split}.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/beam_search.py

View File

@ -0,0 +1,622 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./dprnn_zipformer/decode.py \
--epoch 20 \
--avg 1 \
--use-averaged-model false \
--exp-dir ./dprnn_zipformer/exp_adapt \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./dprnn_zipformer/decode.py \
--epoch 20 \
--avg 1 \
--use-averaged-model false \
--exp-dir ./dprnn_zipformer/exp_adapt \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./dprnn_zipformer/decode.py \
--epoch 20 \
--avg 1 \
--use-averaged-model false \
--exp-dir ./dprnn_zipformer/exp_adapt \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AmiAsrDataModule
from beam_search import (
beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.utils import EPSILON
from train import add_model_arguments, get_params, get_surt_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_surt_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=20,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="dprnn_zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
feature_lens = batch["input_lens"].to(device)
# Apply the mask encoder
B, T, F = feature.shape
processed = model.mask_encoder(feature) # B,T,F*num_channels
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
x_masked = [feature * m for m in masks]
# Recognition
# Stack the inputs along the batch axis
h = torch.cat(x_masked, dim=0)
h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
if model.joint_encoder_layer is not None:
encoder_out = model.joint_encoder_layer(encoder_out)
def _group_channels(hyps: List[str]) -> List[List[str]]:
"""
Currently we have a batch of size M*B, where M is the number of
channels and B is the batch size. We need to group the hypotheses
into B groups, each of which contains M hypotheses.
Example:
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
"""
assert len(hyps) == B * params.num_channels
out_hyps = []
for i in range(B):
out_hyps.append(hyps[i::B])
return out_hyps
hyps = []
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp)
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp))
if params.decoding_method == "greedy_search":
return {"greedy_search": _group_channels(hyps)}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: _group_channels(hyps)}
else:
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
cut_ids = [cut.id for cut in batch["cuts"]]
cuts_batch = batch["cuts"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
)
for name, hyps in hyps_dict.items():
this_batch = []
for cut_id, hyp_words in zip(cut_ids, hyps):
# Reference is a list of supervision texts sorted by start time.
ref_words = [
s.text.strip()
for s in sorted(
cuts_batch[cut_id].supervisions, key=lambda s: s.start
)
]
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(cut_ids)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_surt_error_stats(
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
num_channels=params.num_channels,
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LmScorer.add_arguments(parser)
AmiAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
), f"Decoding method {params.decoding_method} is not supported."
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_surt_model(params)
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
model.encoder.decode_chunk_size,
params.decode_chunk_len,
)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
ami = AmiAsrDataModule(args)
# NOTE(@desh2608): we filter segments longer than 120s to avoid OOM errors in decoding.
# However, 99.9% of the segments are shorter than 120s, so this should not
# substantially affect the results. In future, we will implement an overlapped
# inference method to avoid OOM errors.
test_sets = {}
for split in ["dev", "test"]:
for type in ["ihm-mix", "sdm", "mdm8-bf"]:
test_sets[f"ami-{split}_{type}"] = (
ami.ami_cuts(split=split, type=type)
.trim_to_supervision_groups(max_pause=0.0)
.filter(lambda c: 0.1 < c.duration < 120.0)
.to_eager()
)
for split in ["dev", "test"]:
for type in ["ihm-mix", "sdm"]:
test_sets[f"icsi-{split}_{type}"] = (
ami.icsi_cuts(split=split, type=type)
.trim_to_supervision_groups(max_pause=0.0)
.filter(lambda c: 0.1 < c.duration < 120.0)
.to_eager()
)
for test_set, test_cuts in test_sets.items():
test_dl = ami.test_dataloaders(test_cuts)
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/decoder.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/dprnn.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/encoder_interface.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/export.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/model.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/optim.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/zipformer.py

View File

@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file adds source features as temporal arrays to the mixture manifests.
It looks for manifests in the directory data/manifests.
"""
import logging
from pathlib import Path
import numpy as np
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
from tqdm import tqdm
def add_source_feats():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
logging.info("Reading mixed cuts")
mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz")
mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz")
logging.info("Reading source cuts")
source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz")
logging.info("Adding source features to the mixed cuts")
pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features")
with CutSet.open_writer(
src_dir / "cuts_train_clean_sources.jsonl.gz"
) as cut_writer_clean, CutSet.open_writer(
src_dir / "cuts_train_reverb_sources.jsonl.gz"
) as cut_writer_reverb, LilcomChunkyWriter(
output_dir / "feats_train_clean_sources"
) as source_feat_writer:
for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb):
assert cut_reverb.id == cut_clean.id + "_rvb"
source_feats = []
source_feat_offsets = []
cur_offset = 0
for sup in sorted(
cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
):
source_cut = source_cuts[sup.id]
source_feats.append(source_cut.load_features())
source_feat_offsets.append(cur_offset)
cur_offset += source_cut.num_frames
cut_clean.source_feats = source_feat_writer.store_array(
cut_clean.id, np.concatenate(source_feats, axis=0)
)
cut_clean.source_feat_offsets = source_feat_offsets
cut_writer_clean.write(cut_clean)
# Also write the reverb cut
cut_reverb.source_feats = cut_clean.source_feats
cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets
cut_writer_reverb.write(cut_reverb)
pbar.update(1)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
add_source_feats()

View File

@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the synthetically mixed AMI and ICSI
train set.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import random
import warnings
from pathlib import Path
import torch
import torch.multiprocessing
import torchaudio
from lhotse import (
AudioSource,
LilcomChunkyWriter,
Recording,
load_manifest,
load_manifest_lazy,
)
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.cut import MixedCut, MixTrack, MultiCut
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.utils import fix_random_seed, uuid4
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)
def compute_fbank_aimix():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz")
# only uses RIRs and noises from REVERB challenge
real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
lambda r: "RVB2014" in r.id
)
noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
lambda r: "RVB2014" in r.id
)
# Apply perturbation to the training cuts
logging.info("Applying perturbation to the training cuts")
train_cuts_rvb = train_cuts.map(
lambda c: augment(
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
)
)
logging.info("Extracting fbank features for training cuts")
_ = train_cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / "ai-mix_feats_clean",
manifest_path=src_dir / "cuts_train_clean.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
_ = train_cuts_rvb.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / "ai-mix_feats_reverb",
manifest_path=src_dir / "cuts_train_reverb.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
"""
Given a mixed cut, this function optionally applies the following augmentations:
- Perturbing the SNRs of the tracks (in range [-5, 5] dB)
- Reverberation using a randomly selected RIR
- Adding noise
- Perturbing the loudness (in range [-20, -25] dB)
"""
out_cut = cut.drop_features()
# Perturb the SNRs (optional)
if perturb_snr:
snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
if i == 0:
# Skip the first track since it is the reference
continue
track.snr = snr
# Reverberate the cut (optional)
if rirs is not None:
# Select an RIR at random
rir = random.choice(rirs)
# Select a channel at random
rir_channel = random.choice(list(range(rir.num_channels)))
# Reverberate the cut
out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])
# Add noise (optional)
if noises is not None:
# Select a noise recording at random
noise = random.choice(noises).to_cut()
if isinstance(noise, MultiCut):
noise = noise.to_mono()[0]
# Select an SNR at random
snr = random.uniform(10, 30)
# Repeat the noise to match the duration of the cut
noise = repeat_cut(noise, out_cut.duration)
out_cut = MixedCut(
id=out_cut.id,
tracks=[
MixTrack(cut=out_cut, type="MixedCut"),
MixTrack(cut=noise, type="DataCut", snr=snr),
],
)
# Perturb the loudness (optional)
if perturb_loudness:
target_loudness = random.uniform(-20, -25)
out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
return out_cut
def repeat_cut(cut, duration):
while cut.duration < duration:
cut = cut.mix(cut, offset_other_by=cut.duration)
return cut.truncate(duration=duration)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
fix_random_seed(42)
compute_fbank_aimix()

View File

@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the AMI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.
The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_ami():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests = {}
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
manifests[part] = read_manifests_if_cached(
dataset_parts=["train", "dev", "test"],
output_dir=src_dir,
prefix=f"ami-{part}",
suffix="jsonl.gz",
)
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
for split in ["train", "dev", "test"]:
logging.info(f"Processing {part} {split}")
cuts = CutSet.from_manifests(
**manifests[part][split]
).compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"ami-{part}_{split}_feats",
manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_ami()

View File

@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the ICSI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.
The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_icsi():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests = {}
for part in ["ihm-mix", "sdm"]:
manifests[part] = read_manifests_if_cached(
dataset_parts=["train"],
output_dir=src_dir,
prefix=f"icsi-{part}",
suffix="jsonl.gz",
)
for part in ["ihm-mix", "sdm"]:
for split in ["train"]:
logging.info(f"Processing {part} {split}")
cuts = CutSet.from_manifests(
**manifests[part][split]
).compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"icsi-{part}_{split}_feats",
manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_icsi()

View File

@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the trimmed sub-segments which will be
used for simulating the training mixtures.
The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
import torchaudio
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)
def compute_fbank_ihm():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests = {}
for data in ["ami", "icsi"]:
manifests[data] = read_manifests_if_cached(
dataset_parts=["train"],
output_dir=src_dir,
types=["recordings", "supervisions"],
prefix=f"{data}-ihm",
suffix="jsonl.gz",
)
logging.info("Computing features")
for data in ["ami", "icsi"]:
cs = CutSet.from_manifests(**manifests[data]["train"])
cs = cs.trim_to_supervisions(keep_overlapping=False)
cs = cs.normalize_loudness(target=-23.0, affix_id=False)
cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1)
_ = cs.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"{data}-ihm_train_feats",
manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_ihm()

View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file creates AMI train segments.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import LilcomChunkyWriter, load_manifest_lazy
from lhotse.cut import Cut, CutSet
from lhotse.utils import EPSILON, add_durations
from tqdm import tqdm
def cut_into_windows(cuts: CutSet, duration: float):
"""
This function takes a CutSet and cuts each cut into windows of roughly
`duration` seconds. By roughly, we mean that we try to adjust for the last supervision
that exceeds the duration, or is shorter than the duration.
"""
res = []
with tqdm() as pbar:
for cut in cuts:
pbar.update(1)
sups = cut.index_supervisions()[cut.id]
sr = cut.sampling_rate
start = 0.0
end = duration
num_tries = 0
while start < cut.duration and num_tries < 2:
# Find the supervision that are cut by the window endpoint
hitlist = [iv for iv in sups.at(end) if iv.begin < end]
# If there are no supervisions, we are done
if not hitlist:
res.append(
cut.truncate(
offset=start,
duration=add_durations(end, -start, sampling_rate=sr),
keep_excessive_supervisions=False,
)
)
# Update the start and end for the next window
start = end
end = add_durations(end, duration, sampling_rate=sr)
else:
# find ratio of durations cut by the window endpoint
ratios = [
add_durations(end, -iv.end, sampling_rate=sr) / iv.length()
for iv in hitlist
]
# we retain the supervisions that have >50% of their duration
# in the window, and discard the others
retained = []
discarded = []
for iv, ratio in zip(hitlist, ratios):
if ratio > 0.5:
retained.append(iv)
else:
discarded.append(iv)
cur_end = max(iv.end for iv in retained) if retained else end
res.append(
cut.truncate(
offset=start,
duration=add_durations(cur_end, -start, sampling_rate=sr),
keep_excessive_supervisions=False,
)
)
# For the next window, we start at the earliest discarded supervision
next_start = min(iv.begin for iv in discarded) if discarded else end
next_end = add_durations(next_start, duration, sampling_rate=sr)
# It may happen that next_start is the same as start, in which case
# we will advance the window anyway
if next_start == start:
logging.warning(
f"Next start is the same as start: {next_start} == {start} for cut {cut.id}"
)
start = end + EPSILON
end = add_durations(start, duration, sampling_rate=sr)
num_tries += 1
else:
start = next_start
end = next_end
return CutSet.from_cuts(res)
def prepare_train_cuts():
src_dir = Path("data/manifests")
logging.info("Loading the manifests")
train_cuts_ihm = load_manifest_lazy(
src_dir / "cuts_ami-ihm-mix_train.jsonl.gz"
).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map(
lambda c: c.with_id(f"{c.id}_sdm")
)
train_cuts_mdm = load_manifest_lazy(
src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz"
).map(lambda c: c.with_id(f"{c.id}_mdm8-bf"))
# Combine all cuts into one CutSet
train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm
train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)
# Combine the two segmentations
train_all = train_cuts_1 + train_cuts_2
# At this point, some of the cuts may be very long. We will cut them into windows of
# roughly 30 seconds.
logging.info("Cutting the segments into windows of 30 seconds")
train_all_30 = cut_into_windows(train_all, duration=30.0)
logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")
# Show statistics
train_all.describe(full=True)
# Save the cuts
logging.info("Saving the cuts")
train_all.to_file(src_dir / "cuts_train_ami.jsonl.gz")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
prepare_train_cuts()

View File

@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file creates ICSI train segments.
"""
import logging
from pathlib import Path
from lhotse import load_manifest_lazy
from prepare_ami_train_cuts import cut_into_windows
def prepare_train_cuts():
src_dir = Path("data/manifests")
logging.info("Loading the manifests")
train_cuts_ihm = load_manifest_lazy(
src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz"
).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map(
lambda c: c.with_id(f"{c.id}_sdm")
)
# Combine all cuts into one CutSet
train_cuts = train_cuts_ihm + train_cuts_sdm
train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)
# Combine the two segmentations
train_all = train_cuts_1 + train_cuts_2
# At this point, some of the cuts may be very long. We will cut them into windows of
# roughly 30 seconds.
logging.info("Cutting the segments into windows of 30 seconds")
train_all_30 = cut_into_windows(train_all, duration=30.0)
logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")
# Show statistics
train_all.describe(full=True)
# Save the cuts
logging.info("Saving the cuts")
train_all.to_file(src_dir / "cuts_train_icsi.jsonl.gz")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
prepare_train_cuts()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

195
egs/ami/SURT/prepare.sh Executable file
View File

@ -0,0 +1,195 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/ami
# You can find audio and transcripts for AMI in this path.
#
# - $dl_dir/icsi
# You can find audio and transcripts for ICSI in this path.
#
# - $dl_dir/rirs_noises
# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
#
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/amicorpus,
# you can create a symlink
#
# ln -sfv /path/to/amicorpus $dl_dir/amicorpus
#
if [ ! -d $dl_dir/amicorpus ]; then
for mic in ihm ihm-mix sdm mdm8-bf; do
lhotse download ami --mic $mic $dl_dir/amicorpus
done
fi
# If you have pre-downloaded it to /path/to/icsi,
# you can create a symlink
#
# ln -sfv /path/to/icsi $dl_dir/icsi
#
if [ ! -d $dl_dir/icsi ]; then
lhotse download icsi $dl_dir/icsi
fi
# If you have pre-downloaded it to /path/to/rirs_noises,
# you can create a symlink
#
# ln -sfv /path/to/rirs_noises $dl_dir/
#
if [ ! -d $dl_dir/rirs_noises ]; then
lhotse download rirs_noises $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare AMI manifests"
# We assume that you have downloaded the AMI corpus
# to $dl_dir/amicorpus. We perform text normalization for the transcripts.
mkdir -p data/manifests
for mic in ihm ihm-mix sdm mdm8-bf; do
log "Preparing AMI manifest for $mic"
lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare ICSI manifests"
# We assume that you have downloaded the ICSI corpus
# to $dl_dir/icsi. We perform text normalization for the transcripts.
mkdir -p data/manifests
log "Preparing ICSI manifest"
for mic in ihm ihm-mix sdm; do
lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare RIRs"
# We assume that you have downloaded the RIRS_NOISES corpus
# to $dl_dir/rirs_noises
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 3: Extract features for AMI and ICSI recordings"
python local/compute_fbank_ami.py
python local/compute_fbank_icsi.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Create sources for simulating mixtures"
# In the following script, we speed-perturb the IHM recordings and extract features.
python local/compute_fbank_ihm.py
lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
lhotse filter 'duration<=12.0' - - |\
shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Create training mixtures"
lhotse workflows simulate-meetings \
--method conversational \
--same-spk-pause 0.5 \
--diff-spk-pause 0.5 \
--diff-spk-overlap 1.0 \
--prob-diff-spk-overlap 0.8 \
--num-meetings 200000 \
--num-speakers-per-meeting 2,3 \
--max-duration-per-speaker 15.0 \
--max-utterances-per-speaker 3 \
--seed 1234 \
--num-jobs 2 \
data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
data/manifests/ai-mix_cuts_clean.jsonl.gz
python local/compute_fbank_aimix.py
# Add source features to the manifest (will be used for masking loss)
# This may take ~2 hours.
python local/add_source_feats.py
# Combine clean and reverb
cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
<(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Create training mixtures from real sessions"
python local/prepare_ami_train_cuts.py
python local/prepare_icsi_train_cuts.py
# Combine AMI and ICSI
cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
<(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
mkdir -p data/lm
cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
<(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
> data/lm/transcript_words.txt
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)"
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# Add special words to words.txt
echo "<eps> 0" > $lang_dir/words.txt
echo "!SIL 1" >> $lang_dir/words.txt
echo "<UNK> 2" >> $lang_dir/words.txt
# Add regular words to words.txt
cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt
# Add remaining special word symbols expected by LM scripts.
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "<s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "</s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "#0 ${num_words}" >> $lang_dir/words.txt
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript data/lm/transcript_words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
fi
fi

1
egs/ami/SURT/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -45,7 +45,7 @@ def get_args():
def normalize_text(utt: str) -> str:
utt = re.sub(r"[{0}]+".format("-"), " ", utt)
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
return re.sub(r"[^a-zA-Z\s']", "", utt).upper()
def preprocess_commonvoice(

249
egs/libricss/SURT/README.md Normal file
View File

@ -0,0 +1,249 @@
# Introduction
This is a multi-talker ASR recipe for the LibriCSS dataset. We train a Streaming
Unmixing and Recognition Transducer (SURT) model for the task. In this README,
we will describe the task, the model, and the training process. We will also
provide links to pre-trained models and training logs.
## Task
LibriCSS is a multi-talker meeting corpus formed from mixing together LibriSpeech utterances
and replaying in a real meeting room. It consists of 10 1-hour sessions of audio, each
recorded on a 7-channel microphone. The sessions are recorded at a sampling rate of 16 kHz.
For more information, refer to the paper:
Z. Chen et al., "Continuous speech separation: dataset and analysis,"
ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
Barcelona, Spain, 2020
In this recipe, we perform the "continuous, streaming, multi-talker ASR" task on LibriCSS.
* By "continuous", we mean that the model should be able to transcribe unsegmented audio
without the need of an external VAD.
* By "streaming", we mean that the model has limited right context. We use a right-context
of at most 32 frames (320 ms).
* By "multi-talker", we mean that the model should be able to transcribe overlapping speech
from multiple speakers.
For now, we do not care about speaker attribution, i.e., the transcription is speaker
agnostic. The evaluation depends on the particular model type. In this case, we use
the optimal reference combination WER (ORC-WER) metric as implemented in the
[meeteval](https://github.com/fgnt/meeteval) toolkit.
## Model
We use the Streaming Unmixing and Recognition Transducer (SURT) model for this task.
The model is based on the papers:
- Lu, Liang et al. “Streaming End-to-End Multi-Talker Speech Recognition.” IEEE Signal Processing Letters 28 (2020): 803-807.
- Raj, Desh et al. “Continuous Streaming Multi-Talker ASR with Dual-Path Transducers.” ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2021): 7317-7321.
The model is a combination of a speech separation model and a speech recognition model,
but trained end-to-end with a single loss function. The overall architecture is shown
in the figure below. Note that this architecture is slightly different from the one
in the above papers. A detailed description of the model can be found in the following
paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).
<p align="center">
<img src="surt.png">
Streaming Unmixing and Recognition Transducer
</p>
In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
and a Zipfomer-based recognition network. But other combinations are possible as well.
## Training objective
We train the model using the pruned transducer loss, similar to other ASR recipes in
icefall. However, an important consideration is how to assign references to the output
channels (2 in this case). For this, we use the heuristic error assignment training (HEAT)
strategy, which assigns references to the first available channel based on their start
times. An illustrative example is shown in the figure below:
<p align="center">
<img src="heat.png">
Illustration of HEAT-based reference assignment.
</p>
## Description of the recipe
### Pre-requisites
The recipes in this directory need the following packages to be installed:
- [meeteval](https://github.com/fgnt/meeteval)
- [einops](https://github.com/arogozhnikov/einops)
Additionally, we initialize the "recognition" transducer with a pre-trained model,
trained on LibriSpeech. For this, please run the following from within `egs/librispeech/ASR`:
```bash
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python pruned_transducer_stateless7_streaming/train.py \
--use-fp16 True \
--exp-dir pruned_transducer_stateless7_streaming/exp \
--world-size 4 \
--max-duration 800 \
--num-epochs 10 \
--keep-last-k 1 \
--manifest-dir data/manifests \
--enable-musan true \
--master-port 54321 \
--bpe-model data/lang_bpe_500/bpe.model \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dims 768,768,768,768,768 \
--nhead 8,8,8,8,8 \
--encoder-dims 256,256,256,256,256 \
--attention-dims 192,192,192,192,192 \
--encoder-unmasked-dims 192,192,192,192,192 \
--zipformer-downsampling-factors 1,2,4,8,2 \
--cnn-module-kernels 31,31,31,31,31 \
--decoder-dim 512 \
--joiner-dim 512
```
The above is for SURT-base (~26M). For SURT-large (~38M), use `--num-encoder-layers 2,4,3,2,4`.
Once the above model is trained for 10 epochs, copy it to `egs/libricss/SURT/exp`:
```bash
cp -r pruned_transducer_stateless7_streaming/exp/epoch-10.pt exp/zipformer_base.pt
```
**NOTE:** We also provide this pre-trained checkpoint (see the section below), so you can skip
the above step if you want.
### Training
To train the model, run the following from within `egs/libricss/SURT`:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python dprnn_zipformer/train.py \
--use-fp16 True \
--exp-dir dprnn_zipformer/exp/surt_base \
--world-size 4 \
--max-duration 500 \
--max-duration-valid 250 \
--max-cuts 200 \
--num-buckets 50 \
--num-epochs 30 \
--enable-spec-aug True \
--enable-musan False \
--ctc-loss-scale 0.2 \
--heat-loss-scale 0.2 \
--base-lr 0.004 \
--model-init-ckpt exp/zipformer_base.pt \
--chunk-width-randomization True \
--num-mask-encoder-layers 4 \
--num-encoder-layers 2,2,2,2,2
```
The above is for SURT-base (~26M). For SURT-large (~38M), use:
```bash
--num-mask-encoder-layers 6 \
--num-encoder-layers 2,4,3,2,4 \
--model-init-ckpt exp/zipformer_large.pt \
```
**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.
### Adaptation
The training step above only trains on simulated mixtures. For best results, we also
adapt the final model on the LibriCSS dev set. For this, run the following from within
`egs/libricss/SURT`:
```bash
export CUDA_VISIBLE_DEVICES="0"
python dprnn_zipformer/train_adapt.py \
--use-fp16 True \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--world-size 1 \
--max-duration 500 \
--max-duration-valid 250 \
--max-cuts 200 \
--num-buckets 50 \
--num-epochs 8 \
--lr-epochs 2 \
--enable-spec-aug True \
--enable-musan False \
--ctc-loss-scale 0.2 \
--base-lr 0.0004 \
--model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
--chunk-width-randomization True \
--num-mask-encoder-layers 4 \
--num-encoder-layers 2,2,2,2,2
```
For SURT-large, use the following config:
```bash
--num-mask-encoder-layers 6 \
--num-encoder-layers 2,4,3,2,4 \
--model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
--num-epochs 15 \
--lr-epochs 4 \
```
### Decoding
To decode the model, run the following from within `egs/libricss/SURT`:
#### Greedy search
```bash
export CUDA_VISIBLE_DEVICES="0"
python dprnn_zipformer/decode.py \
--epoch 8 --avg 1 --use-averaged-model False \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--max-duration 250 \
--decoding-method greedy_search
```
#### Beam search
```bash
python dprnn_zipformer/decode.py \
--epoch 8 --avg 1 --use-averaged-model False \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--max-duration 250 \
--decoding-method modified_beam_search \
--beam-size 4
```
## Results (using beam search)
#### IHM-Mix
| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:|
| dprnn_zipformer (base) | 26.7 | 5.1 | 4.2 | 13.7 | 18.7 | 20.5 | 20.6 | 13.8 |
| dprnn_zipformer (large) | 37.9 | 4.6 | 3.8 | 12.7 | 14.3 | 16.7 | 21.2 | 12.2 |
#### SDM
| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:|
| dprnn_zipformer (base) | 26.7 | 6.8 | 7.2 | 21.4 | 24.5 | 28.6 | 31.2 | 20.0 |
| dprnn_zipformer (large) | 37.9 | 6.4 | 6.9 | 17.9 | 19.7 | 25.2 | 25.5 | 16.9 |
## Pre-trained models and logs
* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer>
* Training logs:
- surt_base: <https://tensorboard.dev/experiment/YLGJTkBETb2aqDQ61jbxvQ/>
- surt_base_adapt: <https://tensorboard.dev/experiment/pjXMFVL9RMej85rMHyd0EQ/>
- surt_large: <https://tensorboard.dev/experiment/82HvYqfrSOKZ4w8Jod2QMw/>
- surt_large_adapt: <https://tensorboard.dev/experiment/5oIdEgRqS9Wb6yVuxaExEw/>

View File

@ -0,0 +1,372 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2023 Johns Hopkins Univrtsity (Author: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutMix,
DynamicBucketingSampler,
K2SurtDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriCssAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--max-duration-valid",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--max-cuts",
type=int,
default=100,
help="Maximum number of cuts in a single batch. You can "
"reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
return_sources: bool = True,
strict: bool = True,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SurtDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
return_sources=return_sources,
strict=strict,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
quadratic_duration=30.0,
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
logging.info("About to create dev dataset")
validate = K2SurtDataset(
input_strategy=OnTheFlyFeatures(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
return_sources=False,
strict=False,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration_valid,
max_cuts=self.args.max_cuts,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SurtDataset(
input_strategy=OnTheFlyFeatures(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
return_sources=False,
strict=False,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration_valid,
max_cuts=self.args.max_cuts,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def lsmix_cuts(
self,
rvb_affix: str = "clean",
type_affix: str = "full",
sources: bool = True,
) -> CutSet:
logging.info("About to get train cuts")
source_affix = "_sources" if sources else ""
cs = load_manifest_lazy(
self.args.manifest_dir
/ f"cuts_train_{rvb_affix}_{type_affix}{source_affix}.jsonl.gz"
)
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
return cs
@lru_cache()
def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
logging.info(f"About to get LibriCSS {split} {type} cuts")
cs = load_manifest_lazy(
self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
)
return cs

View File

@ -0,0 +1,730 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import k2
import torch
from model import SURT
from icefall import NgramLmStateCost
from icefall.utils import DecodingResults
def greedy_search(
model: SURT,
encoder_out: torch.Tensor,
max_sym_per_frame: int,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance.
Args:
model:
An instance of `SURT`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 4
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = next(model.parameters()).device
decoder_input = torch.tensor(
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
T = encoder_out.size(1)
t = 0
hyp = [blank_id] * context_size
# timestamp[i] is the frame index after subsampling
# on which hyp[i] is decoded
timestamp = []
# Maximum symbols per utterance.
max_sym_per_utt = 1000
# symbols per frame
sym_per_frame = 0
# symbols per utterance decoded so far
sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y not in (blank_id, unk_id):
hyp.append(y)
timestamp.append(t)
decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
sym_per_utt += 1
sym_per_frame += 1
else:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
if not return_timestamps:
return hyp
else:
return DecodingResults(
hyps=[hyp],
timestamps=[timestamp],
)
def greedy_search_batch(
model: SURT,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The SURT model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = next(model.parameters()).device
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
# timestamp[n][i] is the frame index after subsampling
# on which hyp[n][i] is decoded
timestamps = [[] for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out: (N, 1, decoder_out_dim)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits'shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v not in (blank_id, unk_id):
hyps[i].append(v)
timestamps[i].append(t)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
hyps=ans,
timestamps=ans_timestamps,
)
def modified_beam_search(
model: SURT,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The SURT model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
hyps=ans,
timestamps=ans_timestamps,
)
def beam_search(
model: SURT,
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
espnet/nets/beam_search_SURT.py#L247 is used as a reference.
Args:
model:
An instance of `SURT`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
T = encoder_out.size(1)
t = 0
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
max_sym_per_utt = 20000
sym_per_utt = 0
decoder_cache: Dict[str, torch.Tensor] = {}
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
A = B
B = HypothesisList()
joint_cache: Dict[str, torch.Tensor] = {}
# TODO(fangjun): Implement prefix search to update the `log_prob`
# of hypotheses in A
while True:
y_star = A.get_most_probable()
A.remove(y_star)
cached_key = y_star.key
if cached_key not in decoder_cache:
decoder_input = torch.tensor(
[y_star.ys[-context_size:]],
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
decoder_cache[cached_key] = decoder_out
else:
decoder_out = decoder_cache[cached_key]
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(
current_encoder_out,
decoder_out.unsqueeze(1),
project_input=False,
)
# TODO(fangjun): Scale the blank posterior
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
joint_cache[cached_key] = log_prob
else:
log_prob = joint_cache[cached_key]
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(
Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
timestamp=y_star.timestamp[:],
)
)
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i in (blank_id, unk_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_timestamp = y_star.timestamp + [t]
A.add(
Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
)
)
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = A.get_most_probable()
kept_B = B.filter(A_most_probable.log_prob)
if len(kept_B) >= beam:
B = kept_B.topk(beam)
break
t += 1
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
if not return_timestamps:
return ys
else:
return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None
# the RNNLM states (h and c in LSTM)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existed one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int) -> "HypothesisList":
"""Return the top-k hypothesis."""
hyps = list(self._data.items())
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans

Some files were not shown because too many files have changed in this diff Show More