mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00

Merge branch 'dev_swbd' of https://github.com/JinZr/icefall into dev_swbd

This commit is contained in:
commit 58d9088010
@@ -21,9 +21,9 @@ tree $repo/
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/Linv.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/cpu_jit.pt"
45 .github/workflows/build-docker-image.yml vendored Normal file
@@ -0,0 +1,45 @@
# see also
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages

name: Build docker image

on:
  workflow_dispatch:

concurrency:
  group: build_docker-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build-docker-image:
    name: ${{ matrix.image }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]

    steps:
      # refer to https://github.com/actions/checkout
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Rename
        shell: bash
        run: |
          image=${{ matrix.image }}
          mv -v ./docker/$image.dockerfile ./Dockerfile

      - name: Log in to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}

      - name: Build and push
        uses: docker/build-push-action@v4
        with:
          context: .
          file: ./Dockerfile
          push: true
          tags: k2fsa/icefall:${{ matrix.image }}
4 .github/workflows/run-aishell-2022-06-20.yml vendored
@@ -44,7 +44,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -119,5 +119,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20
+        name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-06-20
         path: egs/aishell/ASR/pruned_transducer_stateless3/exp/
92 .github/workflows/run-docker-image.yml vendored Normal file
@@ -0,0 +1,92 @@
name: Run docker image

on:
  workflow_dispatch:

concurrency:
  group: run_docker_image-${{ github.ref }}
  cancel-in-progress: true

jobs:
  run-docker-image:
    name: ${{ matrix.image }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
    steps:
      # refer to https://github.com/actions/checkout
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Run the build process with Docker
        uses: addnab/docker-run-action@v3
        with:
          image: k2fsa/icefall:${{ matrix.image }}
          shell: bash
          run: |
            uname -a
            cat /etc/*release

            nvcc --version

            # For torch1.9.0-cuda10.2
            export LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat:$LD_LIBRARY_PATH

            # For torch1.12.1-cuda11.3
            export LD_LIBRARY_PATH=/usr/local/cuda-11.3/compat:$LD_LIBRARY_PATH

            # For torch2.0.0-cuda11.7
            export LD_LIBRARY_PATH=/usr/local/cuda-11.7/compat:$LD_LIBRARY_PATH


            which nvcc
            cuda_dir=$(dirname $(which nvcc))
            echo "cuda_dir: $cuda_dir"

            find $cuda_dir -name libcuda.so*
            echo "--------------------"

            find / -name libcuda.so* 2>/dev/null

            # for torch1.13.0-cuda11.6
            if [ -e /opt/conda/lib/stubs/libcuda.so ]; then
              cd /opt/conda/lib/stubs && ln -s libcuda.so libcuda.so.1 && cd -
              export LD_LIBRARY_PATH=/opt/conda/lib/stubs:$LD_LIBRARY_PATH
            fi

            find / -name libcuda.so* 2>/dev/null
            echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"

            python3 --version
            which python3

            python3 -m pip list

            echo "----------torch----------"
            python3 -m torch.utils.collect_env

            echo "----------k2----------"
            python3 -c "import k2; print(k2.__file__)"
            python3 -c "import k2; print(k2.__dev_version__)"
            python3 -m k2.version

            echo "----------lhotse----------"
            python3 -c "import lhotse; print(lhotse.__file__)"
            python3 -c "import lhotse; print(lhotse.__version__)"

            echo "----------kaldifeat----------"
            python3 -c "import kaldifeat; print(kaldifeat.__file__)"
            python3 -c "import kaldifeat; print(kaldifeat.__version__)"

            echo "Test yesno recipe"

            cd egs/yesno/ASR

            ./prepare.sh

            ./tdnn/train.py

            ./tdnn/decode.py
@@ -43,7 +43,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -122,5 +122,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
         path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/
@@ -43,7 +43,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -155,5 +155,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless-2022-03-12
         path: egs/librispeech/ASR/pruned_transducer_stateless/exp/
@@ -43,7 +43,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -174,12 +174,12 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-04-29
         path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/

     - name: Upload decoding results for pruned_transducer_stateless3
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
         path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
@@ -43,7 +43,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -155,5 +155,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless5-2022-05-13
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless5-2022-05-13
         path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/
@@ -155,5 +155,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-2022-11-11
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-2022-11-11
         path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/
@@ -155,5 +155,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14
         path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/
@@ -159,5 +159,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-2022-12-01
         path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/
@@ -163,5 +163,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer_mmi-2022-12-08
         path: egs/librispeech/ASR/zipformer_mmi/exp/
@@ -168,5 +168,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-streaming-2022-12-29
         path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-name: run-librispeech-2022-12-15-stateless7-ctc-bs
+name: run-librispeech-2023-01-29-stateless7-ctc-bs
 # zipformer

 on:

@@ -34,7 +34,7 @@ on:
     - cron: "50 15 * * *"

 jobs:
-  run_librispeech_2022_12_15_zipformer_ctc_bs:
+  run_librispeech_2023_01_29_zipformer_ctc_bs:
     if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:

@@ -124,7 +124,7 @@ jobs:
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

-        .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
+        .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh

     - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'

@@ -159,5 +159,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29
         path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/
@@ -151,5 +151,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-conformer_ctc3-2022-11-28
         path: egs/librispeech/ASR/conformer_ctc3/exp/
@@ -26,7 +26,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.8]

       fail-fast: false

@@ -159,5 +159,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03
         path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
@@ -43,7 +43,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -153,5 +153,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
         path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
@@ -43,7 +43,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -155,5 +155,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-06-26
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-06-26
         path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
@@ -170,5 +170,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
         path: egs/librispeech/ASR/zipformer/exp/
@@ -43,7 +43,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -155,5 +155,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless2-2022-04-19
         path: egs/librispeech/ASR/transducer_stateless2/exp/
@@ -155,5 +155,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
         path: egs/librispeech/ASR/zipformer/exp/
@@ -151,5 +151,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
         path: egs/librispeech/ASR/zipformer/exp/
@@ -33,7 +33,7 @@ jobs:
     runs-on: ${{ matrix.os }}
    strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false
@@ -42,7 +42,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -154,5 +154,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
         path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
@@ -42,7 +42,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -154,5 +154,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
         path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
@@ -33,7 +33,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false
@@ -33,7 +33,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false
@@ -42,7 +42,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false

@@ -154,5 +154,5 @@ jobs:
       uses: actions/upload-artifact@v2
       if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
       with:
-        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
+        name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless-2022-02-07
         path: egs/librispeech/ASR/transducer_stateless/exp/
@@ -33,7 +33,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.7, 3.8, 3.9]

       fail-fast: false
@@ -33,7 +33,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-18.04]
+        os: [ubuntu-latest]
         python-version: [3.8]

       fail-fast: false
2 .github/workflows/run-yesno-recipe.yml vendored
@@ -33,7 +33,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        # os: [ubuntu-18.04, macos-10.15]
+        # os: [ubuntu-latest, macos-10.15]
         # TODO: enable macOS for CPU testing
         os: [ubuntu-latest]
         python-version: [3.8]
57 .github/workflows/test.yml vendored
@@ -35,9 +35,9 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python-version: ["3.8"]
-        torch: ["1.10.0"]
-        torchaudio: ["0.10.0"]
-        k2-version: ["1.23.2.dev20221201"]
+        torch: ["1.13.0"]
+        torchaudio: ["0.13.0"]
+        k2-version: ["1.24.3.dev20230719"]

       fail-fast: false

@@ -66,14 +66,14 @@ jobs:
         pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
         pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html

-        pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+        pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html
         pip install git+https://github.com/lhotse-speech/lhotse
         # icefall requirements
         pip uninstall -y protobuf
         pip install --no-binary protobuf protobuf==3.20.*

         pip install kaldifst
-        pip install onnxruntime
+        pip install onnxruntime matplotlib
         pip install -r requirements.txt

@@ -83,13 +83,6 @@ jobs:
         python3 -m pip install -qq graphviz
         sudo apt-get -qq install graphviz

-    - name: Install graphviz
-      if: startsWith(matrix.os, 'macos')
-      shell: bash
-      run: |
-        python3 -m pip install -qq graphviz
-        brew install -q graphviz
-
     - name: Run tests
       if: startsWith(matrix.os, 'ubuntu')
       run: |

@@ -129,40 +122,10 @@ jobs:
        cd ../transducer_lstm
        pytest -v -s

    - name: Run tests
      if: startsWith(matrix.os, 'macos')
      run: |
        ls -lh
        export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
        lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
        echo "lib_path: $lib_path"
        export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
        pytest -v -s ./test

        # run tests for conformer ctc
        cd egs/librispeech/ASR/conformer_ctc
        cd ../zipformer
        pytest -v -s

        cd ../pruned_transducer_stateless
        pytest -v -s

        cd ../pruned_transducer_stateless2
        pytest -v -s

        cd ../pruned_transducer_stateless3
        pytest -v -s

        cd ../pruned_transducer_stateless4
        pytest -v -s

        cd ../transducer_stateless
        pytest -v -s

        # cd ../transducer
        # pytest -v -s

        cd ../transducer_stateless2
        pytest -v -s

        cd ../transducer_lstm
        pytest -v -s
    - uses: actions/upload-artifact@v2
      with:
        path: egs/librispeech/ASR/zipformer/swoosh.pdf
        name: swoosh.pdf
@@ -1,5 +1,20 @@
# icefall dockerfile

## Download from dockerhub

You can find pre-built docker images for icefall at the following address:

<https://hub.docker.com/r/k2fsa/icefall/tags>

Example usage:

```bash
docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
```

## Build from dockerfile

Two sets of configurations are provided: (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.

If your NVIDIA driver supports CUDA 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
70 docker/torch1.12.1-cuda11.3.dockerfile Normal file
@@ -0,0 +1,70 @@
FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

ARG K2_VERSION="1.24.3.dev20230725+cuda11.3.torch1.12.1"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.3.torch1.12.1"
ARG TORCHAUDIO_VERSION="0.12.1+cu113"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        curl \
        vim \
        libssl-dev \
        autoconf \
        automake \
        bzip2 \
        ca-certificates \
        ffmpeg \
        g++ \
        gfortran \
        git \
        libtool \
        make \
        patch \
        sox \
        subversion \
        unzip \
        valgrind \
        wget \
        zlib1g-dev \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies (the version specifier is quoted so that the shell
# does not treat ">" as a redirection)
RUN pip install --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      "sentencepiece>=0.1.96" \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      multi_quantization \
      typeguard \
      numpy \
      pytest \
      graphviz

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall
72 docker/torch1.13.0-cuda11.6.dockerfile Normal file
@@ -0,0 +1,72 @@
FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

ARG K2_VERSION="1.24.3.dev20230725+cuda11.6.torch1.13.0"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.6.torch1.13.0"
ARG TORCHAUDIO_VERSION="0.13.0+cu116"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        curl \
        vim \
        libssl-dev \
        autoconf \
        automake \
        bzip2 \
        ca-certificates \
        ffmpeg \
        g++ \
        gfortran \
        git \
        libtool \
        make \
        patch \
        sox \
        subversion \
        unzip \
        valgrind \
        wget \
        zlib1g-dev \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies (the version specifier is quoted so that the shell
# does not treat ">" as a redirection)
RUN pip install --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      "sentencepiece>=0.1.96" \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      multi_quantization \
      typeguard \
      numpy \
      pytest \
      graphviz

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

ENV LD_LIBRARY_PATH /opt/conda/lib/stubs:$LD_LIBRARY_PATH

WORKDIR /workspace/icefall
86 docker/torch1.9.0-cuda10.2.dockerfile Normal file
@@ -0,0 +1,86 @@
FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda10.2.torch1.9.0"
ARG TORCHAUDIO_VERSION="0.9.0"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

# see https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/

RUN rm /etc/apt/sources.list.d/cuda.list && \
    rm /etc/apt/sources.list.d/nvidia-ml.list && \
    apt-key del 7fa2af80

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        curl \
        vim \
        libssl-dev \
        autoconf \
        automake \
        bzip2 \
        ca-certificates \
        ffmpeg \
        g++ \
        gfortran \
        git \
        libtool \
        make \
        patch \
        sox \
        subversion \
        unzip \
        valgrind \
        wget \
        zlib1g-dev \
    && rm -rf /var/lib/apt/lists/*

RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb && \
    dpkg -i cuda-keyring_1.0-1_all.deb && \
    rm -v cuda-keyring_1.0-1_all.deb && \
    apt-get update && \
    rm -rf /var/lib/apt/lists/*

# Install dependencies (version specifiers are quoted so that the shell
# does not treat ">" as a redirection)
RUN pip uninstall -y tqdm && \
    pip install -U --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      "sentencepiece>=0.1.96" \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      multi_quantization \
      typeguard \
      numpy \
      pytest \
      graphviz \
      "tqdm>=4.63.0"

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall
70 docker/torch2.0.0-cuda11.7.dockerfile Normal file
@@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

ARG K2_VERSION="1.24.3.dev20230718+cuda11.7.torch2.0.0"
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.7.torch2.0.0"
ARG TORCHAUDIO_VERSION="2.0.0+cu117"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        curl \
        vim \
        libssl-dev \
        autoconf \
        automake \
        bzip2 \
        ca-certificates \
        ffmpeg \
        g++ \
        gfortran \
        git \
        libtool \
        make \
        patch \
        sox \
        subversion \
        unzip \
        valgrind \
        wget \
        zlib1g-dev \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies (the version specifier is quoted so that the shell
# does not treat ">" as a redirection)
RUN pip install --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      "sentencepiece>=0.1.96" \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      multi_quantization \
      typeguard \
      numpy \
      pytest \
      graphviz

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall
@@ -86,7 +86,13 @@ rst_epilog = """
.. _git-lfs: https://git-lfs.com/
.. _ncnn: https://github.com/tencent/ncnn
.. _LibriSpeech: https://www.openslr.org/12
.. _Gigaspeech: https://github.com/SpeechColab/GigaSpeech
.. _musan: http://www.openslr.org/17/
.. _ONNX: https://github.com/onnx/onnx
.. _onnxruntime: https://github.com/microsoft/onnxruntime
.. _torch: https://github.com/pytorch/pytorch
.. _torchaudio: https://github.com/pytorch/audio
.. _k2: https://github.com/k2-fsa/k2
.. _lhotse: https://github.com/lhotse-speech/lhotse
.. _yesno: https://www.openslr.org/1/
"""
184 docs/source/decoding-with-langugage-models/LODR.rst Normal file
@@ -0,0 +1,184 @@
.. _LODR:

LODR for RNN Transducer
=======================


As a type of E2E model, neural transducers are usually considered to have an internal
language model, which learns language-level information from the training corpus.
In real-life scenarios, there is often a mismatch between the training corpus and the target corpus.
This mismatch can be a problem when decoding neural transducer models with external language models, as the internal
language model can act "against" the external LM. In this tutorial, we show how to use
`Low-order Density Ratio <https://arxiv.org/abs/2203.16776>`_ to alleviate this effect and further improve the performance
of language model integration.

.. note::

    This tutorial is based on the recipe
    `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
    which is a streaming transducer model trained on `LibriSpeech`_.
    However, you can easily apply LODR to other recipes.
    If you encounter any problems, please open an issue at `icefall <https://github.com/k2-fsa/icefall/issues>`__.


.. note::

    For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However,
    you can change the testing set to any other domain (e.g. `GigaSpeech`_) and prepare the language models
    using that corpus.

First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) was first proposed `here <https://arxiv.org/abs/2002.11268>`_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:

.. math::

    \text{score}\left(y_u|\mathit{x},y\right) =
    \log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
    \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
    \lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right)


where :math:`\lambda_1` and :math:`\lambda_2` are the weights of the LM scores for the target domain and the source domain, respectively.
Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
shallow fusion is the subtraction of the source domain LM.

Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, this LM is
considered to be weak and can only capture low-level language information. Therefore, `LODR <https://arxiv.org/abs/2203.16776>`__ proposed to use
a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula
during decoding for the transducer model:

.. math::

    \text{score}\left(y_u|\mathit{x},y\right) =
    \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) +
    \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
    \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)

In LODR, an additional bi-gram LM estimated on the source domain (e.g. the training corpus) is required. Compared to DR,
the only difference lies in the choice of the source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
LODR achieves performance similar to DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
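To make the formula concrete, here is a minimal Python sketch of the per-token score combination that LODR performs during beam search. The function and argument names are illustrative only, not icefall's actual API; in icefall, this logic lives inside the ``modified_beam_search_LODR`` decoding method, and the two scales correspond to ``--lm-scale`` and ``--ngram-lm-scale`` used later in this tutorial.

.. code-block:: python

    def lodr_score(rnnt_logp: float,
                   target_lm_logp: float,
                   bigram_logp: float,
                   lm_scale: float,
                   lodr_scale: float) -> float:
        """Combine log-scores for one candidate token y_u (hypothetical sketch).

        rnnt_logp:      log p_rnnt(y_u | x, y_{1:u-1}) from the transducer
        target_lm_logp: log p_{Target LM}(y_u | y_{1:u-1}) from the external LM
        bigram_logp:    log p_{bi-gram}(y_u | y_{u-1}), approximating the
                        transducer's internal LM
        """
        # lodr_scale is chosen negative in practice, so the bi-gram score is
        # effectively subtracted, matching the minus sign in the LODR formula.
        return rnnt_logp + lm_scale * target_lm_logp + lodr_scale * bigram_logp

Beam search then keeps the hypotheses with the highest combined scores.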
Now, we will show you how to use LODR in ``icefall``.
For illustration purposes, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
The testing scenario here is intra-domain (we decode a model trained on `LibriSpeech`_ on the `LibriSpeech`_ testing sets).

As the initial step, let's download the pre-trained model.
.. code-block:: bash

    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
    $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded

To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:

.. code-block:: bash

    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
    $ ./pruned_transducer_stateless7_streaming/decode.py \
        --epoch 99 \
        --avg 1 \
        --use-averaged-model False \
        --exp-dir $exp_dir \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search
The following WERs are achieved on test-clean and test-other:

.. code-block:: text

    $ For test-clean, WER of different settings are:
    $ beam_size_4 3.11 best for test-clean
    $ For test-other, WER of different settings are:
    $ beam_size_4 7.93 best for test-other

Then, we download the external language model and the bi-gram LM that are necessary for LODR.
Note that the bi-gram is estimated on the LibriSpeech 960-hour text.

.. code-block:: bash

    $ # download the external LM
    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
    $ # create a symbolic link so that the checkpoint can be loaded
    $ pushd icefall-librispeech-rnn-lm/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt
    $ popd
    $
    $ # download the bi-gram
    $ git lfs install
    $ git clone https://huggingface.co/marcoyang/librispeech_bigram
    $ pushd data/lang_bpe_500
    $ ln -s ../../librispeech_bigram/2gram.fst.txt .
    $ popd
Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_LODR``:

.. code-block:: bash

    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ lm_dir=./icefall-librispeech-rnn-lm/exp
    $ lm_scale=0.42
    $ LODR_scale=-0.24
    $ ./pruned_transducer_stateless7_streaming/decode.py \
        --epoch 99 \
        --avg 1 \
        --use-averaged-model False \
        --beam-size 4 \
        --exp-dir $exp_dir \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search_LODR \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --use-shallow-fusion 1 \
        --lm-type rnn \
        --lm-exp-dir $lm_dir \
        --lm-epoch 99 \
        --lm-scale $lm_scale \
        --lm-avg 1 \
        --rnn-lm-embedding-dim 2048 \
        --rnn-lm-hidden-dim 2048 \
        --rnn-lm-num-layers 3 \
        --lm-vocab-size 500 \
        --tokens-ngram 2 \
        --ngram-lm-scale $LODR_scale

There are two extra arguments that need to be given when doing LODR. ``--tokens-ngram`` specifies the order of the n-gram; as we
are using a bi-gram, we set it to 2. ``--ngram-lm-scale`` is the scale of the bi-gram; it should be a negative number,
as we are subtracting the bi-gram's score during decoding.

The decoding results obtained with the above command are shown below:

.. code-block:: text

    $ For test-clean, WER of different settings are:
    $ beam_size_4 2.61 best for test-clean
    $ For test-other, WER of different settings are:
    $ beam_size_4 6.74 best for test-other

Recall that the lowest WER we obtained in :ref:`shallow_fusion` with a beam size of 4 is ``2.77/7.08``, so LODR
indeed **further improves** the WER. We can do even better if we increase ``--beam-size``:

.. list-table:: WER of LODR with different beam sizes
   :widths: 25 25 50
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
   * - 4
     - 2.61
     - 6.74
   * - 8
     - 2.45
     - 6.38
   * - 12
     - 2.4
     - 6.23
33 docs/source/decoding-with-langugage-models/index.rst Normal file
@@ -0,0 +1,33 @@
Decoding with language models
=============================

This section describes how to use external language models
during decoding to improve the WER of transducer models.

The following decoding methods with external language models are available
(a conceptual sketch of how they combine scores follows the table):

.. list-table:: Decoding methods with external language models
   :widths: 25 50
   :header-rows: 1

   * - Decoding method
     - Description
   * - ``modified_beam_search``
     - Beam search (i.e. really n-best decoding, the "beam" is the value of n), similar to the original RNN-T paper. Note: this method does not use a language model.
   * - ``modified_beam_search_lm_shallow_fusion``
     - As ``modified_beam_search``, but interpolate RNN-T scores with language model scores, also known as shallow fusion.
   * - ``modified_beam_search_LODR``
     - As ``modified_beam_search_lm_shallow_fusion``, but subtract the score of a (BPE-symbol-level) bigram backoff language model used as an approximation to the internal language model of RNN-T.
   * - ``modified_beam_search_lm_rescore``
     - As ``modified_beam_search``, but rescore the n-best hypotheses with an external language model (e.g. RNNLM) and re-rank them.
   * - ``modified_beam_search_lm_rescore_LODR``
     - As ``modified_beam_search_lm_rescore``, but also subtract the score of a (BPE-symbol-level) bigram backoff language model during re-ranking.
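As a rough mental model, the sketch below shows how these methods relate through the scores they combine; the names are illustrative, not icefall's API. Note one simplification: the shallow-fusion/LODR variants apply their combination at every decoding step, while the rescore variants apply it once per complete hypothesis after beam search.

.. code-block:: python

    # S_rnnt, S_lm, S_2gram: log-probs from the transducer, the external LM,
    # and the BPE-level bigram (an internal-LM approximation), respectively.
    # lm_scale and bigram_scale are tunable positive weights.

    def combined_score(method, S_rnnt, S_lm, S_2gram, lm_scale, bigram_scale):
        if method == "modified_beam_search":
            return S_rnnt                      # no external LM
        if method in ("modified_beam_search_lm_shallow_fusion",
                      "modified_beam_search_lm_rescore"):
            return S_rnnt + lm_scale * S_lm    # add the external LM score
        if method in ("modified_beam_search_LODR",
                      "modified_beam_search_lm_rescore_LODR"):
            # also subtract the bigram's approximation of the internal LM
            return S_rnnt + lm_scale * S_lm - bigram_scale * S_2gram
        raise ValueError(f"unknown method: {method}")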
.. toctree::
   :maxdepth: 2

   shallow-fusion
   LODR
   rescoring
252 docs/source/decoding-with-langugage-models/rescoring.rst Normal file
@@ -0,0 +1,252 @@
.. _rescoring:

LM rescoring for Transducer
=================================

LM rescoring is a commonly used approach to incorporate external LM information. Unlike shallow-fusion-based
methods (see :ref:`shallow_fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search.
Rescoring is usually more efficient than shallow fusion since less computation is performed on the external LM.
In this tutorial, we will show you how to use an external LM to rescore the n-best hypotheses decoded from neural transducer models in
`icefall <https://github.com/k2-fsa/icefall>`__.
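Conceptually, rescoring looks like the Python sketch below: beam search first produces an n-best list scored by the transducer alone, and the external LM then scores each complete hypothesis once. The names (``nbest``, ``lm_logp``) are hypothetical stand-ins, not icefall's actual API.

.. code-block:: python

    from typing import Callable, List, Tuple

    def rescore_nbest(
        nbest: List[Tuple[List[int], float]],   # (token ids, transducer log-prob)
        lm_logp: Callable[[List[int]], float],  # external LM scorer (hypothetical)
        lm_scale: float,
    ) -> List[int]:
        """Re-rank n-best hypotheses with an external LM; return the best one."""
        def combined(hyp: Tuple[List[int], float]) -> float:
            tokens, transducer_logp = hyp
            # Unlike shallow fusion, the LM is invoked once per complete
            # hypothesis instead of once per decoding step.
            return transducer_logp + lm_scale * lm_logp(tokens)

        return max(nbest, key=combined)[0]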
.. note::

    This tutorial is based on the recipe
    `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
    which is a streaming transducer model trained on `LibriSpeech`_.
    However, you can easily apply LM rescoring to other recipes.
    If you encounter any problems, please open an issue `here <https://github.com/k2-fsa/icefall/issues>`_.

.. note::

    For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
    to any other domain (e.g. `GigaSpeech`_) and use an external LM trained on that domain.

.. HINT::

    We recommend using a GPU for decoding.

For illustration purposes, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
As the initial step, let's download the pre-trained model.

.. code-block:: bash

    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
    $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded

As usual, we first test the model's performance without an external LM. This can be done via the following command:

.. code-block:: bash

    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
    $ ./pruned_transducer_stateless7_streaming/decode.py \
        --epoch 99 \
        --avg 1 \
        --use-averaged-model False \
        --exp-dir $exp_dir \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search

The following WERs are achieved on test-clean and test-other:

.. code-block:: text

    $ For test-clean, WER of different settings are:
    $ beam_size_4 3.11 best for test-clean
    $ For test-other, WER of different settings are:
    $ beam_size_4 7.93 best for test-other
Now, we will try to improve the above WER numbers via external LM rescoring. We will download
a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.

.. note::

    This is an RNN LM trained on the LibriSpeech text corpus, so it might not be ideal for other corpora.
    You may also train an RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
    for training an RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.

.. code-block:: bash

    $ # download the external LM
    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
    $ # create a symbolic link so that the checkpoint can be loaded
    $ pushd icefall-librispeech-rnn-lm/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt
    $ popd
With the RNN LM available, we can rescore the n-best hypotheses generated from ``modified_beam_search``. Here,
``n`` should be the number of beams, i.e. ``--beam-size``. The command for LM rescoring is
as follows. Note that ``--decoding-method`` is set to ``modified_beam_search_lm_rescore`` and ``--use-shallow-fusion``
is set to ``False``.

.. code-block:: bash

    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ lm_dir=./icefall-librispeech-rnn-lm/exp
    $ lm_scale=0.43
    $ ./pruned_transducer_stateless7_streaming/decode.py \
        --epoch 99 \
        --avg 1 \
        --use-averaged-model False \
        --beam-size 4 \
        --exp-dir $exp_dir \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search_lm_rescore \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --use-shallow-fusion 0 \
        --lm-type rnn \
        --lm-exp-dir $lm_dir \
        --lm-epoch 99 \
        --lm-scale $lm_scale \
        --lm-avg 1 \
        --rnn-lm-embedding-dim 2048 \
        --rnn-lm-hidden-dim 2048 \
        --rnn-lm-num-layers 3 \
        --lm-vocab-size 500

.. code-block:: text

    $ For test-clean, WER of different settings are:
    $ beam_size_4 2.93 best for test-clean
    $ For test-other, WER of different settings are:
    $ beam_size_4 7.6 best for test-other

Great! We made some improvements! Increasing the size of the n-best hypotheses will further boost the performance;
see the following table:

.. list-table:: WERs of LM rescoring with different beam sizes
   :widths: 25 25 25
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
   * - 4
     - 2.93
     - 7.6
   * - 8
     - 2.67
     - 7.11
   * - 12
     - 2.59
     - 6.86
In fact, we can also apply LODR (see :ref:`LODR`) when doing LM rescoring. To do so, we need to
download the bi-gram required by LODR:

.. code-block:: bash

    $ # download the bi-gram
    $ git lfs install
    $ git clone https://huggingface.co/marcoyang/librispeech_bigram
    $ pushd data/lang_bpe_500
    $ ln -s ../../librispeech_bigram/2gram.arpa .
    $ popd

Then we can perform LM rescoring + LODR by changing the decoding method to ``modified_beam_search_lm_rescore_LODR``.

.. note::

    This decoding method requires the `kenlm <https://github.com/kpu/kenlm>`_ dependency. You can install it
    via this command: ``pip install https://github.com/kpu/kenlm/archive/master.zip``.

.. code-block:: bash

    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ lm_dir=./icefall-librispeech-rnn-lm/exp
    $ lm_scale=0.43
    $ ./pruned_transducer_stateless7_streaming/decode.py \
        --epoch 99 \
        --avg 1 \
        --use-averaged-model False \
        --beam-size 4 \
        --exp-dir $exp_dir \
        --max-duration 600 \
        --decode-chunk-len 32 \
        --decoding-method modified_beam_search_lm_rescore_LODR \
        --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
        --use-shallow-fusion 0 \
        --lm-type rnn \
        --lm-exp-dir $lm_dir \
        --lm-epoch 99 \
        --lm-scale $lm_scale \
        --lm-avg 1 \
        --rnn-lm-embedding-dim 2048 \
        --rnn-lm-hidden-dim 2048 \
        --rnn-lm-num-layers 3 \
        --lm-vocab-size 500

You should see the following WERs after executing the commands above:

.. code-block:: text

    $ For test-clean, WER of different settings are:
    $ beam_size_4 2.9 best for test-clean
    $ For test-other, WER of different settings are:
    $ beam_size_4 7.57 best for test-other

It's slightly better than LM rescoring alone. If we further increase the beam size, we will see
further improvements from LM rescoring + LODR:
.. list-table:: WERs of LM rescoring + LODR with different beam sizes
   :widths: 25 25 25
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
   * - 4
     - 2.9
     - 7.57
   * - 8
     - 2.63
     - 7.04
   * - 12
     - 2.52
     - 6.73

As mentioned earlier, LM rescoring is usually faster than shallow-fusion-based methods.
Here, we benchmark their WERs and decoding speed:

.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (the numbers in each field are the WER on test-clean, the WER on test-other, and the decoding time on test-clean)
   :widths: 25 25 25 25
   :header-rows: 1

   * - Decoding method
     - beam=4
     - beam=8
     - beam=12
   * - ``modified_beam_search``
     - 3.11/7.93; 132s
     - 3.1/7.95; 177s
     - 3.1/7.96; 210s
   * - ``modified_beam_search_lm_shallow_fusion``
     - 2.77/7.08; 262s
     - 2.62/6.65; 352s
     - 2.58/6.65; 488s
   * - ``modified_beam_search_LODR``
     - 2.61/6.74; 400s
     - 2.45/6.38; 610s
     - 2.4/6.23; 870s
   * - ``modified_beam_search_lm_rescore``
     - 2.93/7.6; 156s
     - 2.67/7.11; 203s
     - 2.59/6.86; 255s
   * - ``modified_beam_search_lm_rescore_LODR``
     - 2.9/7.57; 160s
     - 2.63/7.04; 203s
     - 2.52/6.73; 263s

.. note::

    Decoding is performed with a single 32GB V100; we set ``--max-duration`` to 600.
    Decoding times here are only for reference and may vary.
176 docs/source/decoding-with-langugage-models/shallow-fusion.rst Normal file
@ -0,0 +1,176 @@
.. _shallow_fusion:

Shallow fusion for Transducer
=================================

External language models (LMs) are commonly used to improve WERs for E2E ASR models.
This tutorial shows you how to perform ``shallow fusion`` with an external LM
to improve the word-error-rate of a transducer model.

.. note::

   This tutorial is based on the recipe
   `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
   which is a streaming transducer model trained on `LibriSpeech`_.
   However, you can easily apply shallow fusion to other recipes.
   If you encounter any problems, please open an issue in `icefall <https://github.com/k2-fsa/icefall/issues>`_.

.. note::

   For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
   to any other domain (e.g. `GigaSpeech`_) and use an external LM trained on that domain.

.. HINT::

   We recommend using a GPU for decoding.

For illustration purposes, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.

As the initial step, let's download the pre-trained model.

.. code-block:: bash

   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
   $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt  # create a symbolic link so that the checkpoint can be loaded
   $ popd  # the decoding commands below are run from the parent directory

To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --exp-dir $exp_dir \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search

The following WERs are achieved on test-clean and test-other:

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4      3.11    best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4      7.93    best for test-other

These are already good numbers! But we can further improve them by using shallow fusion with an external LM.
Training a language model usually takes a long time, so we can download a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.

.. code-block:: bash

   $ # download the external LM
   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
   $ # create a symbolic link so that the checkpoint can be loaded
   $ pushd icefall-librispeech-rnn-lm/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt
   $ popd

.. note::

   This is an RNN LM trained on the LibriSpeech text corpus, so it might not be ideal for other corpora.
   You may also train an RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
   for training an RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.
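If you do choose to train the RNN LM yourself, the invocation looks roughly like the sketch below. This is only a sketch: the flag names are assumptions mirroring the decoding flags used in this tutorial, and ``./my-rnn-lm/exp`` is a hypothetical directory, so please consult ``python3 ./icefall/rnn_lm/train.py --help`` for the authoritative options:

.. code-block:: bash

   $ # hypothetical flags; verify with: python3 ./icefall/rnn_lm/train.py --help
   $ python3 ./icefall/rnn_lm/train.py \
       --exp-dir ./my-rnn-lm/exp \
       --vocab-size 500 \
       --embedding-dim 2048 \
       --hidden-dim 2048 \
       --num-layers 3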
To use shallow fusion for decoding, we can execute the following command:

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ lm_dir=./icefall-librispeech-rnn-lm/exp
   $ lm_scale=0.29
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --beam-size 4 \
       --exp-dir $exp_dir \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search_lm_shallow_fusion \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --use-shallow-fusion 1 \
       --lm-type rnn \
       --lm-exp-dir $lm_dir \
       --lm-epoch 99 \
       --lm-scale $lm_scale \
       --lm-avg 1 \
       --rnn-lm-embedding-dim 2048 \
       --rnn-lm-hidden-dim 2048 \
       --rnn-lm-num-layers 3 \
       --lm-vocab-size 500

Note that we set ``--decoding-method modified_beam_search_lm_shallow_fusion`` and ``--use-shallow-fusion 1``
to use shallow fusion. ``--lm-type`` specifies the type of neural LM we are going to use; you can choose
between ``rnn`` and ``transformer``. The following three arguments are associated with the RNN:

- ``--rnn-lm-embedding-dim``

  The embedding dimension of the RNN LM.

- ``--rnn-lm-hidden-dim``

  The hidden dimension of the RNN LM.

- ``--rnn-lm-num-layers``

  The number of RNN layers in the RNN LM.

The decoding results obtained with the above command are shown below.

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4      2.77    best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4      7.08    best for test-other
The improvement from shallow fusion is very obvious! The relative WER reduction on test-other is around
10.7% ((7.93 - 7.08) / 7.93). A few parameters can be tuned to further boost the performance of shallow fusion:

- ``--lm-scale``

  Controls the scale of the LM. If it is too small, the external language model may not be fully utilized; if it is too large,
  the LM score may dominate during decoding, leading to bad WER. A typical value is around 0.3; a small sweep sketch is shown after this list.

- ``--beam-size``

  The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.
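As a concrete example of such tuning, the sketch below sweeps ``--lm-scale`` over a few values around 0.3, assuming ``exp_dir`` and ``lm_dir`` from the shallow fusion command above are still set; the particular values are illustrative, not recommendations:

.. code-block:: bash

   for lm_scale in 0.25 0.29 0.33; do
     ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --beam-size 4 \
       --exp-dir $exp_dir \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search_lm_shallow_fusion \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --use-shallow-fusion 1 \
       --lm-type rnn \
       --lm-exp-dir $lm_dir \
       --lm-epoch 99 \
       --lm-scale $lm_scale \
       --lm-avg 1 \
       --rnn-lm-embedding-dim 2048 \
       --rnn-lm-hidden-dim 2048 \
       --rnn-lm-num-layers 3 \
       --lm-vocab-size 500
   done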
Here, we also show how ``--beam-size`` affects the WER and decoding time:

.. list-table:: WERs and decoding time (on test-clean) of shallow fusion with different beam sizes
   :widths: 25 25 25 25
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
     - Decoding time on test-clean (s)
   * - 4
     - 2.77
     - 7.08
     - 262
   * - 8
     - 2.62
     - 6.65
     - 352
   * - 12
     - 2.58
     - 6.65
     - 488

As we see, a larger beam size during shallow fusion improves the WER, but is also slower.
BIN
docs/source/docker/img/docker-hub.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 356 KiB |
17
docs/source/docker/index.rst
Normal file
@ -0,0 +1,17 @@
.. _icefall_docker:

Docker
======

This section describes how to use pre-built docker images to run `icefall`_.

.. hint::

   If you only have CPUs available, you can still use the pre-built docker
   images.

.. toctree::
   :maxdepth: 2

   ./intro.rst
171
docs/source/docker/intro.rst
Normal file
@ -0,0 +1,171 @@
Introduction
=============

We have pre-built docker images hosted at the following address:

`<https://hub.docker.com/repository/docker/k2fsa/icefall/general>`_

.. figure:: img/docker-hub.png
   :width: 600
   :align: center

You can find the ``Dockerfile`` at `<https://github.com/k2-fsa/icefall/tree/master/docker>`_.

We describe the following items in this section:

- How to view available tags
- How to download pre-built docker images
- How to run the `yesno`_ recipe within a docker container on ``CPU``

View available tags
===================

You can use the following command to view available tags:

.. code-block:: bash

   curl -s 'https://registry.hub.docker.com/v2/repositories/k2fsa/icefall/tags/' | jq '."results"[]["name"]'

which will give you something like below:

.. code-block:: bash

   "torch2.0.0-cuda11.7"
   "torch1.12.1-cuda11.3"
   "torch1.9.0-cuda10.2"
   "torch1.13.0-cuda11.6"

.. hint::

   Available tags will be updated when there are new releases of `torch`_.

   Please select an appropriate combination of `torch`_ and CUDA.
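If you are unsure which combination matches your local setup, a quick check (a sketch that assumes `torch`_ is installed locally; ``torch.version.cuda`` prints ``None`` for CPU-only builds) is:

.. code-block:: bash

   python3 -c "import torch; print(torch.__version__, torch.version.cuda)"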
Download a docker image
=======================

Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use
the following command to download it:

.. code-block:: bash

   sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6

Run a docker image with GPU
===========================

.. code-block:: bash

   sudo docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
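If your datasets live on the host, you can bind-mount them into the container with docker's standard ``-v`` option; ``/path/to/data`` below is a placeholder you must replace:

.. code-block:: bash

   sudo docker run --gpus all --rm -it \
     -v /path/to/data:/workspace/data \
     k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash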
Run a docker image with CPU
===========================

.. code-block:: bash

   sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash

Run yesno within a docker container
===================================

After starting the container, the following interface is presented:

.. code-block:: bash

   root@60c947eac59c:/workspace/icefall#

It shows that the current user is ``root`` and the current working directory
is ``/workspace/icefall``.

Update the code
---------------

Please first run:

.. code-block:: bash

   root@60c947eac59c:/workspace/icefall# git pull

so that your local copy contains the latest code.

Data preparation
----------------

Now we can use

.. code-block:: bash

   root@60c947eac59c:/workspace/icefall# cd egs/yesno/ASR/

to switch to the ``yesno`` recipe and run

.. code-block:: bash

   root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./prepare.sh

.. hint::

   If you are running without GPU, it may report the following error:

   .. code-block:: bash

      File "/opt/conda/lib/python3.9/site-packages/k2/__init__.py", line 23, in <module>
        from _k2 import DeterminizeWeightPushingType
      ImportError: libcuda.so.1: cannot open shared object file: No such file or directory

   We can use the following command to fix it:

   .. code-block:: bash

      root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ln -s /opt/conda/lib/stubs/libcuda.so /opt/conda/lib/stubs/libcuda.so.1

The logs of running ``./prepare.sh`` are listed below:

.. literalinclude:: ./log/log-preparation.txt

Training
--------

After preparing the data, we can start training with the following command:

.. code-block:: bash

   root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/train.py

All of the training logs are given below:

.. hint::

   It is running on CPU and it takes only 16 seconds for this run.

.. literalinclude:: ./log/log-train-2023-08-01-01-55-27

Decoding
--------

After training, we can decode the trained model with:

.. code-block:: bash

   root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/decode.py

The decoding logs are given below:

.. code-block:: bash

   2023-08-01 02:06:22,400 INFO [decode.py:263] Decoding started
   2023-08-01 02:06:22,400 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d663.clean', 'torch-version': '1.13.0', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.9', 'icefall-git-branch': 'master', 'icefall-git-sha1': '375520d-clean', 'icefall-git-date': 'Fri Jul 28 07:43:08 2023', 'icefall-path': '/workspace/icefall', 'k2-path': '/opt/conda/lib/python3.9/site-packages/k2/__init__.py', 'lhotse-path': '/opt/conda/lib/python3.9/site-packages/lhotse/__init__.py', 'hostname': '60c947eac59c', 'IP address': '172.17.0.2'}}
   2023-08-01 02:06:22,401 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
   2023-08-01 02:06:22,403 INFO [decode.py:273] device: cpu
   2023-08-01 02:06:22,406 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
   2023-08-01 02:06:22,424 INFO [asr_datamodule.py:218] About to get test cuts
   2023-08-01 02:06:22,425 INFO [asr_datamodule.py:252] About to get test cuts
   2023-08-01 02:06:22,504 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
   [W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.
   2023-08-01 02:06:22,687 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
   2023-08-01 02:06:22,688 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
   2023-08-01 02:06:22,690 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
   2023-08-01 02:06:22,690 INFO [decode.py:316] Done!

Congratulations! You have finished successfully running `icefall`_ within a docker container.
@ -21,9 +21,11 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
   :caption: Contents:

   installation/index
   docker/index
   faqs
   model-export/index


.. toctree::
   :maxdepth: 3

@ -34,3 +36,8 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.

   contributing/index
   huggingface/index

.. toctree::
   :maxdepth: 2

   decoding-with-langugage-models/index
@ -3,40 +3,28 @@
Installation
============

.. hint::

   We also provide :ref:`icefall_docker` support, which has already set up
   the environment for you.

.. hint::

   We have a colab notebook guiding you step by step to set up the environment.

   |yesno colab notebook|

.. |yesno colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing

`icefall`_ depends on `k2`_ and `lhotse`_.

We recommend that you use the following steps to install the dependencies.

- (0) Install CUDA toolkit and cuDNN
- (1) Install `torch`_ and `torchaudio`_
- (2) Install `k2`_
- (3) Install `lhotse`_

.. caution::

   99% of users who have issues with the installation are using conda.

.. caution::

   99% of users who have issues with the installation are using conda.

.. caution::

   99% of users who have issues with the installation are using conda.

.. hint::

   We suggest that you use ``pip install`` to install PyTorch.

   You can use the following command to create a virtual environment in Python:

   .. code-block:: bash

      python3 -m venv ./my_env
      source ./my_env/bin/activate

.. caution::

@ -50,27 +38,20 @@ Please refer to
to install CUDA and cuDNN.
(1) Install torch and torchaudio
--------------------------------

Please refer to `<https://pytorch.org/>`_ to install `torch`_ and `torchaudio`_.

.. caution::

   Please install torch and torchaudio at the same time.

(2) Install k2
--------------

Please refer to `<https://k2-fsa.github.io/k2/installation/index.html>`_
to install `k2`_.

.. caution::

@ -78,21 +59,18 @@ to install ``k2``.

.. note::

   We suggest that you install k2 from pre-compiled wheels by following
   `<https://k2-fsa.github.io/k2/installation/from_wheels.html>`_

.. hint::

   Please always install the latest version of `k2`_.

(3) Install lhotse
------------------

Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
to install `lhotse`_.

.. hint::

@ -100,17 +78,16 @@ to install ``lhotse``.

   pip install git+https://github.com/lhotse-speech/lhotse

   to install the latest version of `lhotse`_.

(4) Download icefall
--------------------

`icefall`_ is a collection of Python scripts; all you need to do is download it
and set the environment variable ``PYTHONPATH`` to point to it.

Assume you want to place `icefall`_ in the folder ``/tmp``. The
following commands show you how to set up `icefall`_:

.. code-block:: bash

@ -122,285 +99,334 @@ following commands show you how to setup ``icefall``:

.. HINT::

   You can put several versions of `icefall`_ in the same virtual environment.
   To switch among different versions of `icefall`_, just set ``PYTHONPATH``
   to point to the version you want.
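For instance, a sketch of switching between two checkouts (the paths ``/tmp/icefall-a`` and ``/tmp/icefall-b`` are hypothetical placeholders):

.. code-block:: bash

   # use version a
   export PYTHONPATH=/tmp/icefall-a:$PYTHONPATH

   # later, in a fresh shell, switch to version b
   export PYTHONPATH=/tmp/icefall-b:$PYTHONPATH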
Installation example
--------------------

The following shows an example of setting up the environment.

(1) Create a virtual environment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: bash

   kuangfangjun:~$ virtualenv -p python3.8 test-icefall
   created virtual environment CPython3.8.0.final.0-64 in 9422ms
     creator CPython3Posix(dest=/star-fj/fangjun/test-icefall, clear=False, no_vcs_ignore=False, global=False)
     seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/star-fj/fangjun/.local/share/virtualenv)
       added seed packages: pip==22.3.1, setuptools==65.6.3, wheel==0.38.4
     activators BashActivator,CShellActivator,FishActivator,NushellActivator,PowerShellActivator,PythonActivator

   kuangfangjun:~$ source test-icefall/bin/activate

   (test-icefall) kuangfangjun:~$

(2) Install CUDA toolkit and cuDNN
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You need to determine the version of CUDA toolkit to install.

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ nvidia-smi | head -n 4

   Wed Jul 26 21:57:49 2023
   +-----------------------------------------------------------------------------+
   | NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
   |-------------------------------+----------------------+----------------------+

You can choose any CUDA version that is ``not`` greater than the version printed by ``nvidia-smi``.
In our case, we can choose any version ``<= 11.6``.
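If you just want that number on its own, a small sketch (assuming ``nvidia-smi`` is in your ``PATH``) is:

.. code-block:: bash

   nvidia-smi | grep -o 'CUDA Version: [0-9.]*'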
We will use ``CUDA 11.6`` in this example. Please follow
`<https://k2-fsa.github.io/k2/installation/cuda-cudnn.html#cuda-11-6>`_
to install CUDA toolkit and cuDNN if you have not done that before.

After installing CUDA toolkit, you can use the following command to verify it:

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ nvcc --version

   nvcc: NVIDIA (R) Cuda compiler driver
   Copyright (c) 2005-2019 NVIDIA Corporation
   Built on Wed_Oct_23_19:24:38_PDT_2019
   Cuda compilation tools, release 10.2, V10.2.89

(3) Install torch and torchaudio
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Since we have selected CUDA toolkit ``11.6``, we have to install a version of `torch`_
that is compiled against CUDA ``11.6``. We select ``torch 1.13.0+cu116`` in this
example.

After selecting the version of `torch`_ to install, we also need to install
a compatible version of `torchaudio`_, which is ``0.13.0+cu116`` in our case.

Please refer to `<https://pytorch.org/audio/stable/installation.html#compatibility-matrix>`_
to select an appropriate version of `torchaudio`_ to install if you use a different
version of `torch`_.

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ pip install torch==1.13.0+cu116 torchaudio==0.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html

   Looking in links: https://download.pytorch.org/whl/torch_stable.html
   Collecting torch==1.13.0+cu116
     Downloading https://download.pytorch.org/whl/cu116/torch-1.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl (1983.0 MB)
        ________________________________________ 2.0/2.0 GB 764.4 kB/s eta 0:00:00
   Collecting torchaudio==0.13.0+cu116
     Downloading https://download.pytorch.org/whl/cu116/torchaudio-0.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl (4.2 MB)
        ________________________________________ 4.2/4.2 MB 1.3 MB/s eta 0:00:00
   Requirement already satisfied: typing-extensions in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from torch==1.13.0+cu116) (4.7.1)
   Installing collected packages: torch, torchaudio
   Successfully installed torch-1.13.0+cu116 torchaudio-0.13.0+cu116

Verify that `torch`_ and `torchaudio`_ are successfully installed:

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ python3 -c "import torch; print(torch.__version__)"

   1.13.0+cu116

   (test-icefall) kuangfangjun:~$ python3 -c "import torchaudio; print(torchaudio.__version__)"

   0.13.0+cu116
(4) Install k2
~~~~~~~~~~~~~~

We will install `k2`_ from pre-compiled wheels by following
`<https://k2-fsa.github.io/k2/installation/from_wheels.html>`_

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html

   Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
   Looking in links: https://k2-fsa.github.io/k2/cuda.html
   Collecting k2==1.24.3.dev20230725+cuda11.6.torch1.13.0
     Downloading https://huggingface.co/csukuangfj/k2/resolve/main/ubuntu-cuda/k2-1.24.3.dev20230725%2Bcuda11.6.torch1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (104.3 MB)
        ________________________________________ 104.3/104.3 MB 5.1 MB/s eta 0:00:00
   Requirement already satisfied: torch==1.13.0 in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from k2==1.24.3.dev20230725+cuda11.6.torch1.13.0) (1.13.0+cu116)
   Collecting graphviz
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/de/5e/fcbb22c68208d39edff467809d06c9d81d7d27426460ebc598e55130c1aa/graphviz-0.20.1-py3-none-any.whl (47 kB)
   Requirement already satisfied: typing-extensions in /star-fj/fangjun/test-icefall/lib/python3.8/site-packages (from torch==1.13.0->k2==1.24.3.dev20230725+cuda11.6.torch1.13.0) (4.7.1)
   Installing collected packages: graphviz, k2
   Successfully installed graphviz-0.20.1 k2-1.24.3.dev20230725+cuda11.6.torch1.13.0

.. hint::

   Please refer to `<https://k2-fsa.github.io/k2/cuda.html>`_ for the available
   pre-compiled wheels about `k2`_.

Verify that `k2`_ has been installed successfully:

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ python3 -m k2.version

   Collecting environment information...

   k2 version: 1.24.3
   Build type: Release
   Git SHA1: 4c05309499a08454997adf500b56dcc629e35ae5
   Git date: Tue Jul 25 16:23:36 2023
   Cuda used to build k2: 11.6
   cuDNN used to build k2: 8.3.2
   Python version used to build k2: 3.8
   OS used to build k2: CentOS Linux release 7.9.2009 (Core)
   CMake version: 3.27.0
   GCC version: 9.3.1
   CMAKE_CUDA_FLAGS: -Wno-deprecated-gpu-targets -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_80,code=sm_80 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_86,code=sm_86 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_86,code=compute_86 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-strict-overflow --compiler-options -Wno-unknown-pragmas
   CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable -Wno-strict-overflow
   PyTorch version used to build k2: 1.13.0+cu116
   PyTorch is using Cuda: 11.6
   NVTX enabled: True
   With CUDA: True
   Disable debug: True
   Sync kernels : False
   Disable checks: False
   Max cpu memory allocate: 214748364800 bytes (or 200.0 GB)
   k2 abort: False
   __file__: /star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/version/version.py
   _k2.__file__: /star-fj/fangjun/test-icefall/lib/python3.8/site-packages/_k2.cpython-38-x86_64-linux-gnu.so
(5) Install lhotse
~~~~~~~~~~~~~~~~~~

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ pip install git+https://github.com/lhotse-speech/lhotse

   Collecting git+https://github.com/lhotse-speech/lhotse
     Cloning https://github.com/lhotse-speech/lhotse to /tmp/pip-req-build-vq12fd5i
     Running command git clone --filter=blob:none --quiet https://github.com/lhotse-speech/lhotse /tmp/pip-req-build-vq12fd5i
     Resolved https://github.com/lhotse-speech/lhotse to commit 7640d663469b22cd0b36f3246ee9b849cd25e3b7
     Installing build dependencies ... done
     Getting requirements to build wheel ... done
     Preparing metadata (pyproject.toml) ... done
   Collecting cytoolz>=0.10.1
     Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1e/3b/a7828d575aa17fb7acaf1ced49a3655aa36dad7e16eb7e6a2e4df0dda76f/cytoolz-0.12.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
        ________________________________________ 2.0/2.0 MB 33.2 MB/s eta 0:00:00
   Collecting pyyaml>=5.3.1
     Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c8/6b/6600ac24725c7388255b2f5add93f91e58a5d7efaf4af244fdbcc11a541b/PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (736 kB)
        ________________________________________ 736.6/736.6 kB 38.6 MB/s eta 0:00:00
   Collecting dataclasses
     Downloading https://pypi.tuna.tsinghua.edu.cn/packages/26/2f/1095cdc2868052dd1e64520f7c0d5c8c550ad297e944e641dbf1ffbb9a5d/dataclasses-0.6-py3-none-any.whl (14 kB)
   Requirement already satisfied: torchaudio in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (0.13.0+cu116)
   Collecting lilcom>=1.1.0
     Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a8/65/df0a69c52bd085ca1ad4e5c4c1a5c680e25f9477d8e49316c4ff1e5084a4/lilcom-1.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (87 kB)
        ________________________________________ 87.1/87.1 kB 8.7 MB/s eta 0:00:00
   Collecting tqdm
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/e6/02/a2cff6306177ae6bc73bc0665065de51dfb3b9db7373e122e2735faf0d97/tqdm-4.65.0-py3-none-any.whl (77 kB)
   Requirement already satisfied: numpy>=1.18.1 in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (1.24.4)
   Collecting audioread>=2.1.9
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/5d/cb/82a002441902dccbe427406785db07af10182245ee639ea9f4d92907c923/audioread-3.0.0.tar.gz (377 kB)
     Preparing metadata (setup.py) ... done
   Collecting tabulate>=0.8.1
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl (35 kB)
   Collecting click>=7.1.1
     Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1a/70/e63223f8116931d365993d4a6b7ef653a4d920b41d03de7c59499962821f/click-8.1.6-py3-none-any.whl (97 kB)
        ________________________________________ 97.9/97.9 kB 8.4 MB/s eta 0:00:00
   Collecting packaging
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none-any.whl (48 kB)
   Collecting intervaltree>=3.1.0
     Downloading https://pypi.tuna.tsinghua.edu.cn/packages/50/fb/396d568039d21344639db96d940d40eb62befe704ef849b27949ded5c3bb/intervaltree-3.1.0.tar.gz (32 kB)
     Preparing metadata (setup.py) ... done
   Requirement already satisfied: torch in ./test-icefall/lib/python3.8/site-packages (from lhotse==1.16.0.dev0+git.7640d66.clean) (1.13.0+cu116)
   Collecting SoundFile>=0.10
     Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ad/bd/0602167a213d9184fc688b1086dc6d374b7ae8c33eccf169f9b50ce6568c/soundfile-0.12.1-py2.py3-none-manylinux_2_17_x86_64.whl (1.3 MB)
        ________________________________________ 1.3/1.3 MB 46.5 MB/s eta 0:00:00
   Collecting toolz>=0.8.0
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/7f/5c/922a3508f5bda2892be3df86c74f9cf1e01217c2b1f8a0ac4841d903e3e9/toolz-0.12.0-py3-none-any.whl (55 kB)
   Collecting sortedcontainers<3.0,>=2.0
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
   Collecting cffi>=1.0
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/b7/8b/06f30caa03b5b3ac006de4f93478dbd0239e2a16566d81a106c322dc4f79/cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (442 kB)
   Requirement already satisfied: typing-extensions in ./test-icefall/lib/python3.8/site-packages (from torch->lhotse==1.16.0.dev0+git.7640d66.clean) (4.7.1)
   Collecting pycparser
     Using cached https://pypi.tuna.tsinghua.edu.cn/packages/62/d5/5f610ebe421e85889f2e55e33b7f9a6795bd982198517d912eb1c76e1a53/pycparser-2.21-py2.py3-none-any.whl (118 kB)
   Building wheels for collected packages: lhotse, audioread, intervaltree
     Building wheel for lhotse (pyproject.toml) ... done
     Created wheel for lhotse: filename=lhotse-1.16.0.dev0+git.7640d66.clean-py3-none-any.whl size=687627 sha256=cbf0a4d2d0b639b33b91637a4175bc251d6a021a069644ecb1a9f2b3a83d072a
     Stored in directory: /tmp/pip-ephem-wheel-cache-wwtk90_m/wheels/7f/7a/8e/a0bf241336e2e3cb573e1e21e5600952d49f5162454f2e612f
     Building wheel for audioread (setup.py) ... done
     Created wheel for audioread: filename=audioread-3.0.0-py3-none-any.whl size=23704 sha256=5e2d3537c96ce9cf0f645a654c671163707bf8cb8d9e358d0e2b0939a85ff4c2
     Stored in directory: /star-fj/fangjun/.cache/pip/wheels/e2/c3/9c/f19ae5a03f8862d9f0776b0c0570f1fdd60a119d90954e3f39
     Building wheel for intervaltree (setup.py) ... done
     Created wheel for intervaltree: filename=intervaltree-3.1.0-py2.py3-none-any.whl size=26098 sha256=2604170976cfffe0d2f678cb1a6e5b525f561cd50babe53d631a186734fec9f9
     Stored in directory: /star-fj/fangjun/.cache/pip/wheels/f3/ed/2b/c179ebfad4e15452d6baef59737f27beb9bfb442e0620f7271
   Successfully built lhotse audioread intervaltree
   Installing collected packages: sortedcontainers, dataclasses, tqdm, toolz, tabulate, pyyaml, pycparser, packaging, lilcom, intervaltree, click, audioread, cytoolz, cffi, SoundFile, lhotse
   Successfully installed SoundFile-0.12.1 audioread-3.0.0 cffi-1.15.1 click-8.1.6 cytoolz-0.12.2 dataclasses-0.6 intervaltree-3.1.0 lhotse-1.16.0.dev0+git.7640d66.clean lilcom-1.7 packaging-23.1 pycparser-2.21 pyyaml-6.0.1 sortedcontainers-2.4.0 tabulate-0.9.0 toolz-0.12.0 tqdm-4.65.0

Verify that `lhotse`_ has been installed successfully:

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ python3 -c "import lhotse; print(lhotse.__version__)"

   1.16.0.dev+git.7640d66.clean
(6) Download icefall
~~~~~~~~~~~~~~~~~~~~

.. code-block:: bash

   (test-icefall) kuangfangjun:~$ cd /tmp/

   (test-icefall) kuangfangjun:tmp$ git clone https://github.com/k2-fsa/icefall

   Cloning into 'icefall'...
   remote: Enumerating objects: 12942, done.
   remote: Counting objects: 100% (67/67), done.
   remote: Compressing objects: 100% (56/56), done.
   remote: Total 12942 (delta 17), reused 35 (delta 6), pack-reused 12875
   Receiving objects: 100% (12942/12942), 14.77 MiB | 9.29 MiB/s, done.
   Resolving deltas: 100% (8835/8835), done.

   (test-icefall) kuangfangjun:tmp$ cd icefall/

   (test-icefall) kuangfangjun:icefall$ pip install -r ./requirements.txt
Test Your Installation
----------------------

To test that your installation is successful, let us run
the `yesno recipe <https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR>`_
on ``CPU``.

Data preparation
~~~~~~~~~~~~~~~~

.. code-block:: bash

   (test-icefall) kuangfangjun:icefall$ export PYTHONPATH=/tmp/icefall:$PYTHONPATH

   (test-icefall) kuangfangjun:icefall$ cd /tmp/icefall

   (test-icefall) kuangfangjun:icefall$ cd egs/yesno/ASR

   (test-icefall) kuangfangjun:ASR$ ./prepare.sh

The log of running ``./prepare.sh`` is:

.. code-block::

   2023-07-27 12:41:39 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
   2023-07-27 12:41:39 (prepare.sh:30:main) Stage 0: Download data
   /tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|___________________________________________________| 4.70M/4.70M [00:00<00:00, 11.1MB/s]
   2023-07-27 12:41:46 (prepare.sh:39:main) Stage 1: Prepare yesno manifest
   2023-07-27 12:41:50 (prepare.sh:45:main) Stage 2: Compute fbank for yesno
   2023-07-27 12:41:55,718 INFO [compute_fbank_yesno.py:65] Processing train
   Extracting and storing features: 100%|_______________________________________________________________________________| 90/90 [00:01<00:00, 87.82it/s]
   2023-07-27 12:41:56,778 INFO [compute_fbank_yesno.py:65] Processing test
   Extracting and storing features: 100%|______________________________________________________________________________| 30/30 [00:00<00:00, 256.92it/s]
   2023-07-27 12:41:57 (prepare.sh:51:main) Stage 3: Prepare lang
   2023-07-27 12:42:02 (prepare.sh:66:main) Stage 4: Prepare G
   /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79
   [I] Reading \data\ section.
   /project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140
   [I] Reading \1-grams: section.
   2023-07-27 12:42:02 (prepare.sh:92:main) Stage 5: Compile HLG
   2023-07-27 12:42:07,275 INFO [compile_hlg.py:124] Processing data/lang_phone
   2023-07-27 12:42:07,276 INFO [lexicon.py:171] Converting L.pt to Linv.pt
   2023-07-27 12:42:07,309 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
   2023-07-27 12:42:07,310 INFO [compile_hlg.py:52] Loading G.fst.txt
   2023-07-27 12:42:07,314 INFO [compile_hlg.py:62] Intersecting L and G
   2023-07-27 12:42:07,323 INFO [compile_hlg.py:64] LG shape: (4, None)
   2023-07-27 12:42:07,323 INFO [compile_hlg.py:66] Connecting LG
   2023-07-27 12:42:07,323 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
   2023-07-27 12:42:07,323 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
   2023-07-27 12:42:07,323 INFO [compile_hlg.py:71] Determinizing LG
   2023-07-27 12:42:07,341 INFO [compile_hlg.py:74] <class '_k2.ragged.RaggedTensor'>
   2023-07-27 12:42:07,341 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
   2023-07-27 12:42:07,341 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
   2023-07-27 12:42:07,354 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None)
   2023-07-27 12:42:07,445 INFO [compile_hlg.py:96] Arc sorting LG
   2023-07-27 12:42:07,445 INFO [compile_hlg.py:99] Composing H and LG
   2023-07-27 12:42:07,446 INFO [compile_hlg.py:106] Connecting LG
   2023-07-27 12:42:07,446 INFO [compile_hlg.py:109] Arc sorting LG
   2023-07-27 12:42:07,447 INFO [compile_hlg.py:111] HLG.shape: (8, None)
   2023-07-27 12:42:07,447 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone
Training
|
||||
~~~~~~~~
|
||||
@ -409,12 +435,13 @@ Now let us run the training part:
|
||||
|
||||
.. code-block::
|
||||
|
||||
$ export CUDA_VISIBLE_DEVICES=""
|
||||
$ ./tdnn/train.py
|
||||
(test-icefall) kuangfangjun:ASR$ export CUDA_VISIBLE_DEVICES=""
|
||||
|
||||
(test-icefall) kuangfangjun:ASR$ ./tdnn/train.py
|
||||
|
||||
.. CAUTION::
|
||||
|
||||
We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
|
||||
We use ``export CUDA_VISIBLE_DEVICES=""`` so that `icefall`_ uses CPU
|
||||
even if there are GPUs available.
|
||||
|
||||
.. hint::
|
||||
@ -432,53 +459,52 @@ The training log is given below:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2023-05-12 18:04:59,759 INFO [train.py:481] Training started
2023-05-12 18:04:59,759 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0,
'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10,
'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0,
'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2,
'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
'k2-path': 'tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
'lhotse-path': 'tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
2023-05-12 18:04:59,761 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-05-12 18:04:59,764 INFO [train.py:495] device: cpu
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:146] About to get train cuts
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:244] About to get train cuts
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:149] About to create train dataset
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:199] Using SingleCutSampler.
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:205] About to create train dataloader
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:218] About to get test cuts
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:252] About to get test cuts
2023-05-12 18:04:59,986 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:00,352 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:00,691 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames.
2023-05-12 18:05:00,996 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:01,217 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames.
2023-05-12 18:05:01,251 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt
2023-05-12 18:05:01,389 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:01,637 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:01,859 INFO [train.py:444] Epoch 1, validation loss=0.1629, over 18067.00 frames.
2023-05-12 18:05:02,094 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.0767, over 2695.00 frames. ], tot_loss[loss=0.118, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:02,350 INFO [train.py:444] Epoch 1, validation loss=0.06778, over 18067.00 frames.
2023-05-12 18:05:02,395 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt
2023-07-27 12:50:51,936 INFO [train.py:481] Training started
2023-07-27 12:50:51,936 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d66.clean', 'torch-version': '1.13.0+cu116', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '3fb0a43-clean', 'icefall-git-date': 'Thu Jul 27 12:36:05 2023', 'icefall-path': '/tmp/icefall', 'k2-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-1-1220091118-57c4d55446-sph26', 'IP address': '10.177.77.20'}}
2023-07-27 12:50:51,941 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-07-27 12:50:51,949 INFO [train.py:495] device: cpu
2023-07-27 12:50:51,965 INFO [asr_datamodule.py:146] About to get train cuts
2023-07-27 12:50:51,965 INFO [asr_datamodule.py:244] About to get train cuts
2023-07-27 12:50:51,967 INFO [asr_datamodule.py:149] About to create train dataset
2023-07-27 12:50:51,967 INFO [asr_datamodule.py:199] Using SingleCutSampler.
2023-07-27 12:50:51,967 INFO [asr_datamodule.py:205] About to create train dataloader
2023-07-27 12:50:51,968 INFO [asr_datamodule.py:218] About to get test cuts
2023-07-27 12:50:51,968 INFO [asr_datamodule.py:252] About to get test cuts
2023-07-27 12:50:52,565 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4
2023-07-27 12:50:53,681 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames.], batch size: 4
2023-07-27 12:50:54,167 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames.
2023-07-27 12:50:55,011 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5
2023-07-27 12:50:55,331 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames.
2023-07-27 12:50:55,368 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt
2023-07-27 12:50:55,633 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4
2023-07-27 12:50:56,242 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames.], batch size: 4
2023-07-27 12:50:56,522 INFO [train.py:444] Epoch 1, validation loss=0.1627, over 18067.00 frames.
2023-07-27 12:50:57,209 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.07055, over 2695.00 frames. ], tot_loss[loss=0.1175, over 34971.47 frames.], batch size: 5
2023-07-27 12:50:57,600 INFO [train.py:444] Epoch 1, validation loss=0.07091, over 18067.00 frames.
2023-07-27 12:50:57,640 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt
2023-07-27 12:50:57,847 INFO [train.py:422] Epoch 2, batch 0, loss[loss=0.07731, over 2436.00 frames. ], tot_loss[loss=0.07731, over 2436.00 frames.], batch size: 4
2023-07-27 12:50:58,427 INFO [train.py:422] Epoch 2, batch 10, loss[loss=0.04391, over 2828.00 frames. ], tot_loss[loss=0.05341, over 22192.90 frames. ], batch size: 4
2023-07-27 12:50:58,884 INFO [train.py:444] Epoch 2, validation loss=0.04384, over 18067.00 frames.
2023-07-27 12:50:59,387 INFO [train.py:422] Epoch 2, batch 20, loss[loss=0.03458, over 2695.00 frames. ], tot_loss[loss=0.04616, over 34971.47 frames. ], batch size: 5
2023-07-27 12:50:59,707 INFO [train.py:444] Epoch 2, validation loss=0.03379, over 18067.00 frames.
2023-07-27 12:50:59,758 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-2.pt

... ...
... ...

2023-05-12 18:05:14,789 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01056, over 2436.00 frames. ], tot_loss[loss=0.01056, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:15,016 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009022, over 2828.00 frames. ], tot_loss[loss=0.009985, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:15,271 INFO [train.py:444] Epoch 13, validation loss=0.01088, over 18067.00 frames.
2023-05-12 18:05:15,497 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01174, over 2695.00 frames. ], tot_loss[loss=0.01077, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:15,747 INFO [train.py:444] Epoch 13, validation loss=0.01087, over 18067.00 frames.
2023-05-12 18:05:15,783 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt
2023-05-12 18:05:15,921 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01045, over 2436.00 frames. ], tot_loss[loss=0.01045, over 2436.00 frames. ], batch size: 4
2023-05-12 18:05:16,146 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008957, over 2828.00 frames. ], tot_loss[loss=0.009903, over 22192.90 frames. ], batch size: 4
2023-05-12 18:05:16,374 INFO [train.py:444] Epoch 14, validation loss=0.01092, over 18067.00 frames.
2023-05-12 18:05:16,598 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01065, over 34971.47 frames. ], batch size: 5
2023-05-12 18:05:16,824 INFO [train.py:444] Epoch 14, validation loss=0.01077, over 18067.00 frames.
2023-05-12 18:05:16,862 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt
2023-05-12 18:05:16,865 INFO [train.py:555] Done!
2023-07-27 12:51:23,433 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01054, over 2436.00 frames. ], tot_loss[loss=0.01054, over 2436.00 frames. ], batch size: 4
2023-07-27 12:51:23,980 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009014, over 2828.00 frames. ], tot_loss[loss=0.009974, over 22192.90 frames. ], batch size: 4
2023-07-27 12:51:24,489 INFO [train.py:444] Epoch 13, validation loss=0.01085, over 18067.00 frames.
2023-07-27 12:51:25,258 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01172, over 2695.00 frames. ], tot_loss[loss=0.01055, over 34971.47 frames. ], batch size: 5
2023-07-27 12:51:25,621 INFO [train.py:444] Epoch 13, validation loss=0.01074, over 18067.00 frames.
2023-07-27 12:51:25,699 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt
2023-07-27 12:51:25,866 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01044, over 2436.00 frames. ], tot_loss[loss=0.01044, over 2436.00 frames. ], batch size: 4
2023-07-27 12:51:26,844 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008942, over 2828.00 frames. ], tot_loss[loss=0.01, over 22192.90 frames. ], batch size: 4
2023-07-27 12:51:27,221 INFO [train.py:444] Epoch 14, validation loss=0.01082, over 18067.00 frames.
2023-07-27 12:51:27,970 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01054, over 34971.47 frames. ], batch size: 5
2023-07-27 12:51:28,247 INFO [train.py:444] Epoch 14, validation loss=0.01073, over 18067.00 frames.
2023-07-27 12:51:28,323 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt
2023-07-27 12:51:28,326 INFO [train.py:555] Done!

Decoding
~~~~~~~~
@ -487,42 +513,32 @@ Let us use the trained model to decode the test set:

.. code-block::

$ ./tdnn/decode.py
(test-icefall) kuangfangjun:ASR$ ./tdnn/decode.py

The decoding log is:
2023-07-27 12:55:12,840 INFO [decode.py:263] Decoding started
2023-07-27 12:55:12,840 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d66.clean', 'torch-version': '1.13.0+cu116', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '3fb0a43-clean', 'icefall-git-date': 'Thu Jul 27 12:36:05 2023', 'icefall-path': '/tmp/icefall', 'k2-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/test-icefall/lib/python3.8/site-packages/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-1-1220091118-57c4d55446-sph26', 'IP address': '10.177.77.20'}}
2023-07-27 12:55:12,841 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-07-27 12:55:12,855 INFO [decode.py:273] device: cpu
2023-07-27 12:55:12,868 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-07-27 12:55:12,882 INFO [asr_datamodule.py:218] About to get test cuts
2023-07-27 12:55:12,883 INFO [asr_datamodule.py:252] About to get test cuts
2023-07-27 12:55:13,157 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
2023-07-27 12:55:13,701 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2023-07-27 12:55:13,702 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2023-07-27 12:55:13,704 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2023-07-27 12:55:13,704 INFO [decode.py:316] Done!

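The ``averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']`` line means that ``decode.py`` (run with ``--epoch 14 --avg 2``) loads the last two checkpoints and averages their parameters before decoding. A minimal sketch of the idea, assuming each checkpoint stores its weights under a ``model`` key as icefall's do; the real helper lives in ``icefall.checkpoint`` and differs in detail:

.. code-block:: python

   import torch

   def average_checkpoints(paths):
       """Average model parameters over several checkpoints (sketch)."""
       avg = torch.load(paths[0], map_location="cpu")["model"]
       for path in paths[1:]:
           state = torch.load(path, map_location="cpu")["model"]
           for k in avg:
               avg[k] = avg[k] + state[k]
       for k in avg:
           if avg[k].is_floating_point():
               avg[k] = avg[k] / len(paths)
       return avg

   # e.g. average_checkpoints(["tdnn/exp/epoch-13.pt", "tdnn/exp/epoch-14.pt"])
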
.. code-block::

2023-05-12 18:08:30,482 INFO [decode.py:263] Decoding started
2023-05-12 18:08:30,483 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23,
'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'),
'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True,
'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
'k2-path': '/tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
'lhotse-path': '/tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
2023-05-12 18:08:30,483 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
2023-05-12 18:08:30,487 INFO [decode.py:273] device: cpu
2023-05-12 18:08:30,513 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:218] About to get test cuts
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:252] About to get test cuts
2023-05-12 18:08:30,675 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
2023-05-12 18:08:30,923 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
2023-05-12 18:08:30,924 INFO [utils.py:558] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
2023-05-12 18:08:30,925 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
2023-05-12 18:08:30,925 INFO [decode.py:316] Done!

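For readers new to the metric: the ``%WER`` line counts insertions, deletions, and substitutions against the number of reference words. Checking the figure reported in both logs (a small illustrative computation, not icefall code):

.. code-block:: python

   # WER = (insertions + deletions + substitutions) / reference words.
   # The logs report "%WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub]".
   ins, dels, subs, ref_words = 0, 1, 0, 240
   wer = 100.0 * (ins + dels + subs) / ref_words
   print(f"{wer:.2f}%")  # 0.42%
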
**Congratulations!** You have successfully set up the environment and have run the first recipe in ``icefall``.
**Congratulations!** You have successfully set up the environment and have run the first recipe in `icefall`_.

Have fun with ``icefall``!

YouTube Video
-------------

We provide the following YouTube video showing how to install ``icefall``.
We provide the following YouTube video showing how to install `icefall`_.
It also shows how to debug various problems that you may encounter while
using ``icefall``.
using `icefall`_.

.. note::

@ -1,7 +1,7 @@
Distillation with HuBERT
========================

This tutorial shows you how to perform knowledge distillation in `icefall`_
This tutorial shows you how to perform knowledge distillation in `icefall <https://github.com/k2-fsa/icefall>`_
with the `LibriSpeech`_ dataset. The distillation method
used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD).
Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation <https://arxiv.org/abs/2211.00508>`_
@ -13,7 +13,7 @@ for more details about MVQ-KD.
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_.
Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes
with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.

.. note::

@ -217,7 +217,7 @@ the following command.
--exp-dir $exp_dir \
--enable-distillation True

You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`_.
You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`__.

That's all! Feel free to experiment with your own setups and report your results.
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`_.
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`__.
@ -8,10 +8,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.

.. Note::

The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial.

.. HINT::
@ -237,7 +237,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.

.. NOTE::

The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from the command line, so that you can train models of different sizes with pruned_transducer_stateless5.


@ -529,13 +529,13 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:

- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`_
- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`__

- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`_
- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`__

- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`_
- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`__

- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`_
- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`__

See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models.
@ -45,9 +45,9 @@ the input features.

We have three variants of Emformer models in ``icefall``.

- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`_.
- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`__.
- ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourselves. Different from the Emformer in torchaudio,
ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model.
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`_.
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`__.
- ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourselves. The only difference from the above one is that
it uses a simplified memory bank. See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless2>`_.
@ -6,10 +6,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.

.. Note::

The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial.

.. HINT::
@ -264,7 +264,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.

.. NOTE::

The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from the command line, so that you can train models of different sizes with pruned_transducer_stateless5.

@ -6,7 +6,7 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.

.. Note::

The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`__,

.. HINT::

@ -642,7 +642,7 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:

- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_
- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__

See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models.
@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
    src_dir = Path("data/manifests/aidatatang_200zh")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@ -109,7 +110,12 @@ def get_args():
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    return parser.parse_args()


@ -119,4 +125,6 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
    compute_fbank_aidatatang_200zh(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
    )
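The new ``--perturb-speed`` flag uses ``str2bool`` as its argparse ``type`` so that ``True``/``False`` strings on the command line become real booleans. For reference, a minimal sketch of such a converter (the actual ``icefall.utils.str2bool`` may differ in detail):

```python
import argparse

def str2bool(v):
    """Convert command-line strings like 'true'/'false' into booleans."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
```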
@ -77,7 +77,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for aidatatang_200zh"
  if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aidatatang_200zh.py
    ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
    touch data/fbank/.aidatatang_200zh.done
  fi
fi
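With ``--perturb-speed True``, the train partition is augmented at 0.9x and 1.1x speed, tripling the number of training cuts. A small illustration of the effect using lhotse (the manifest path is hypothetical):

```python
from lhotse import CutSet

cuts = CutSet.from_file("data/fbank/aidatatang_cuts_train.jsonl.gz")  # illustrative path
# Speed perturbation keeps every original cut and adds two resampled
# copies of each, so the augmented set is three times larger.
cuts_sp = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
assert len(cuts_sp) == 3 * len(cuts)
```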
@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@ -109,7 +110,12 @@ def get_args():
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    return parser.parse_args()


@ -119,4 +125,6 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
    compute_fbank_aidatatang_200zh(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
    )
@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aishell(num_mel_bins: int = 80):
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@ -81,7 +81,8 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@ -104,7 +105,12 @@ def get_args():
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    return parser.parse_args()


@ -114,4 +120,6 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aishell(num_mel_bins=args.num_mel_bins)
    compute_fbank_aishell(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
    )
@ -114,7 +114,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute fbank for aishell"
  if [ ! -f data/fbank/.aishell.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aishell.py
    ./local/compute_fbank_aishell.py --perturb-speed True
    touch data/fbank/.aishell.done
  fi
fi
@ -53,7 +53,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Process aidatatang_200zh"
  if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aidatatang_200zh.py
    ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
    touch data/fbank/.aidatatang_200zh_fbank.done
  fi
fi
@ -240,7 +240,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless3/exp",
        default="pruned_transducer_stateless7/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@ -243,7 +243,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless3/exp",
        default="pruned_transducer_stateless7/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aishell2(num_mel_bins: int = 80):
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@ -81,7 +81,8 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@ -104,6 +105,12 @@ def get_args():
        default=80,
        help="""The number of mel bins for Fbank""",
    )
    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    return parser.parse_args()

@ -114,4 +121,6 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
    compute_fbank_aishell2(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
    )
@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute fbank for aishell2"
  if [ ! -f data/fbank/.aishell2.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aishell2.py
    ./local/compute_fbank_aishell2.py --perturb-speed True
    touch data/fbank/.aishell2.done
  fi
fi
@ -32,7 +32,7 @@ import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aishell4(num_mel_bins: int = 80):
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
    src_dir = Path("data/manifests/aishell4")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@ -83,10 +83,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )

            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
@ -113,6 +115,12 @@ def get_args():
        default=80,
        help="""The number of mel bins for Fbank""",
    )
    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    return parser.parse_args()

@ -123,4 +131,6 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)
    compute_fbank_aishell4(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
    )
@ -107,7 +107,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute fbank for aishell4"
  if [ ! -f data/fbank/.aishell4.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_aishell4.py
    ./local/compute_fbank_aishell4.py --perturb-speed True
    touch data/fbank/.aishell4.done
  fi
fi
@ -32,7 +32,7 @@ import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor
from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_alimeeting(num_mel_bins: int = 80):
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
    src_dir = Path("data/manifests/alimeeting")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
@ -82,7 +82,8 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@ -114,6 +115,12 @@ def get_args():
        default=80,
        help="""The number of mel bins for Fbank""",
    )
    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )

    return parser.parse_args()

@ -124,4 +131,6 @@ if __name__ == "__main__":
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins)
    compute_fbank_alimeeting(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
    )
@ -97,7 +97,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute fbank for alimeeting"
  if [ ! -f data/fbank/.alimeeting.done ]; then
    mkdir -p data/fbank
    ./local/compute_fbank_alimeeting.py
    ./local/compute_fbank_alimeeting.py --perturb-speed True
    touch data/fbank/.alimeeting.done
  fi
fi
@ -25,6 +25,7 @@ It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
from pathlib import Path

@ -39,6 +40,8 @@ from lhotse.features.kaldifeat import (
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
@ -48,7 +51,7 @@ torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")


def compute_fbank_ami():
def compute_fbank_ami(perturb_speed: bool = False):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

@ -84,8 +87,12 @@ def compute_fbank_ami():
        suffix="jsonl.gz",
    )

    def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
        cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
    def _extract_feats(
        cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool
    ) -> None:
        if speed_perturb:
            logging.info(f"Doing speed perturb")
            cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
        _ = cuts.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=storage_path,
@ -109,6 +116,7 @@ def compute_fbank_ami():
        cuts_ihm,
        output_dir / "feats_train_ihm",
        src_dir / "cuts_train_ihm.jsonl.gz",
        perturb_speed,
    )

    logging.info("Processing train split IHM + reverberated IHM")
@ -117,6 +125,7 @@ def compute_fbank_ami():
        cuts_ihm_rvb,
        output_dir / "feats_train_ihm_rvb",
        src_dir / "cuts_train_ihm_rvb.jsonl.gz",
        perturb_speed,
    )

    logging.info("Processing train split SDM")
@ -129,6 +138,7 @@ def compute_fbank_ami():
        cuts_sdm,
        output_dir / "feats_train_sdm",
        src_dir / "cuts_train_sdm.jsonl.gz",
        perturb_speed,
    )

    logging.info("Processing train split GSS")
@ -141,6 +151,7 @@ def compute_fbank_ami():
        cuts_gss,
        output_dir / "feats_train_gss",
        src_dir / "cuts_train_gss.jsonl.gz",
        perturb_speed,
    )

    logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
@ -186,8 +197,21 @@ def compute_fbank_ami():
    )


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_ami()
    args = get_args()

    compute_fbank_ami(perturb_speed=args.perturb_speed)
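Unlike the per-cut extraction in the other scripts above, the AMI script uses lhotse's batched API, which collates many cuts per call. Conceptually it amounts to something like the following (a sketch with a hypothetical manifest path, not the script itself):

```python
from lhotse import CutSet, Fbank, FbankConfig

cuts = CutSet.from_file("data/manifests/cuts_train_ihm.jsonl.gz")  # hypothetical path
extractor = Fbank(FbankConfig(num_mel_bins=80))
# compute_and_store_features_batch processes cuts in batches, which is
# typically much faster than extracting features one cut at a time.
cuts = cuts.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="data/fbank/feats_train_ihm",
    num_workers=4,
)
```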
@ -85,7 +85,7 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Compute fbank for alimeeting"
  mkdir -p data/fbank
  python local/compute_fbank_alimeeting.py
  python local/compute_fbank_alimeeting.py --perturb-speed True
  log "Combine features from train splits"
  lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
    gzip -c > data/manifests/cuts_train_all.jsonl.gz
156
egs/ami/SURT/README.md
Normal file
@ -0,0 +1,156 @@
# Introduction

This is a multi-talker ASR recipe for the AMI and ICSI datasets. We train a Streaming
Unmixing and Recognition Transducer (SURT) model for the task.

Please refer to the `egs/libricss/SURT` recipe README for details about the task and the
model.

## Description of the recipe

### Pre-requisites

The recipes in this directory need the following packages to be installed:

- [meeteval](https://github.com/fgnt/meeteval)
- [einops](https://github.com/arogozhnikov/einops)

Additionally, we initialize the model with the pre-trained model from the LibriCSS recipe.
Please download this checkpoint (see below) or train the LibriCSS recipe first.

### Training

To train the model, run the following from within `egs/ami/SURT`:

```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"

python dprnn_zipformer/train.py \
  --use-fp16 True \
  --exp-dir dprnn_zipformer/exp/surt_base \
  --world-size 4 \
  --max-duration 500 \
  --max-duration-valid 250 \
  --max-cuts 200 \
  --num-buckets 50 \
  --num-epochs 30 \
  --enable-spec-aug True \
  --enable-musan False \
  --ctc-loss-scale 0.2 \
  --heat-loss-scale 0.2 \
  --base-lr 0.004 \
  --model-init-ckpt exp/libricss_base.pt \
  --chunk-width-randomization True \
  --num-mask-encoder-layers 4 \
  --num-encoder-layers 2,2,2,2,2
```

The above is for SURT-base (~26M). For SURT-large (~38M), use:

```bash
  --model-init-ckpt exp/libricss_large.pt \
  --num-mask-encoder-layers 6 \
  --num-encoder-layers 2,4,3,2,4 \
  --model-init-ckpt exp/zipformer_large.pt \
```

**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.

### Adaptation

The training step above only trains on simulated mixtures. For best results, we also
adapt the final model on the AMI+ICSI train set. For this, run the following from within
`egs/ami/SURT`:

```bash
export CUDA_VISIBLE_DEVICES="0"

python dprnn_zipformer/train_adapt.py \
  --use-fp16 True \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --world-size 4 \
  --max-duration 500 \
  --max-duration-valid 250 \
  --max-cuts 200 \
  --num-buckets 50 \
  --num-epochs 8 \
  --lr-epochs 2 \
  --enable-spec-aug True \
  --enable-musan False \
  --ctc-loss-scale 0.2 \
  --base-lr 0.0004 \
  --model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
  --chunk-width-randomization True \
  --num-mask-encoder-layers 4 \
  --num-encoder-layers 2,2,2,2,2
```

For SURT-large, use the following config:

```bash
  --num-mask-encoder-layers 6 \
  --num-encoder-layers 2,4,3,2,4 \
  --model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
  --num-epochs 15 \
  --lr-epochs 4 \
```

### Decoding

To decode the model, run the following from within `egs/ami/SURT`:

#### Greedy search

```bash
export CUDA_VISIBLE_DEVICES="0"

python dprnn_zipformer/decode.py \
  --epoch 20 --avg 1 --use-averaged-model False \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --max-duration 250 \
  --decoding-method greedy_search
```

#### Beam search

```bash
python dprnn_zipformer/decode.py \
  --epoch 20 --avg 1 --use-averaged-model False \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --max-duration 250 \
  --decoding-method modified_beam_search \
  --beam-size 4
```

## Results (using beam search)

**AMI**

| Model      | IHM-Mix  | SDM      | MDM      |
|------------|:--------:|:--------:|:--------:|
| SURT-base  | 39.8     | 65.4     | 46.6     |
| + adapt    | 37.4     | 46.9     | 43.7     |
| SURT-large | 36.8     | 62.5     | 44.4     |
| + adapt    | **35.1** | **44.6** | **41.4** |

**ICSI**

| Model      | IHM-Mix  | SDM      |
|------------|:--------:|:--------:|
| SURT-base  | 28.3     | 60.0     |
| + adapt    | 26.3     | 33.9     |
| SURT-large | 27.8     | 59.7     |
| + adapt    | **24.4** | **32.3** |

## Pre-trained models and logs

* LibriCSS pre-trained model (for initialization): [base](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_base) [large](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_large)

* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-ami-dprnn-zipformer>

* Training logs:
  - surt_base: <https://tensorboard.dev/experiment/8awy98VZSWegLmH4l2JWSA/>
  - surt_base_adapt: <https://tensorboard.dev/experiment/aGVgXVzYRDKbGUbPekcNjg/>
  - surt_large: <https://tensorboard.dev/experiment/ZXMkez0VSYKbPLqRk4clOQ/>
  - surt_large_adapt: <https://tensorboard.dev/experiment/WLKL1e7bTVyEjSonYSNYwg/>
399
egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,399 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutMix,
    DynamicBucketingSampler,
    K2SurtDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class AmiAsrDataModule:
    """
    DataModule for k2 SURT experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/manifests"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--max-duration-valid",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--max-cuts",
            type=int,
            default=100,
            help="Maximum number of cuts in a single batch. You can "
            "reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help=(
                "When enabled, use on-the-fly cut mixing and feature "
                "extraction. Will drop existing precomputed feature manifests "
                "if available."
            ),
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with training dataset. ",
        )

def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
sources: bool = False,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=sources,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
validate = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def aimix_train_cuts(
|
||||
self,
|
||||
rvb_affix: str = "clean",
|
||||
sources: bool = True,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
source_affix = "_sources" if sources else ""
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_train_{rvb_affix}{source_affix}.jsonl.gz"
|
||||
)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(
|
||||
self,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_ami_icsi.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def ami_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
|
||||
logging.info(f"About to get AMI {split} {type} cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_ami-{type}_{split}.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def icsi_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
|
||||
logging.info(f"About to get ICSI {split} {type} cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_icsi-{type}_{split}.jsonl.gz"
|
||||
)
|
1
egs/ami/SURT/dprnn_zipformer/beam_search.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/beam_search.py
|
622
egs/ami/SURT/dprnn_zipformer/decode.py
Executable file
622
egs/ami/SURT/dprnn_zipformer/decode.py
Executable file
@ -0,0 +1,622 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AmiAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.utils import EPSILON
|
||||
from train import add_model_arguments, get_params, get_surt_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_surt_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=20,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="dprnn_zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
feature_lens = batch["input_lens"].to(device)
|
||||
|
||||
# Apply the mask encoder
|
||||
B, T, F = feature.shape
|
||||
processed = model.mask_encoder(feature) # B,T,F*num_channels
|
||||
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
|
||||
x_masked = [feature * m for m in masks]
|
||||
|
||||
# Recognition
|
||||
# Stack the inputs along the batch axis
|
||||
h = torch.cat(x_masked, dim=0)
|
||||
h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
|
||||
|
||||
if model.joint_encoder_layer is not None:
|
||||
encoder_out = model.joint_encoder_layer(encoder_out)
|
||||
|
||||
def _group_channels(hyps: List[str]) -> List[List[str]]:
|
||||
"""
|
||||
Currently we have a batch of size M*B, where M is the number of
|
||||
channels and B is the batch size. We need to group the hypotheses
|
||||
into B groups, each of which contains M hypotheses.
|
||||
|
||||
Example:
|
||||
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
|
||||
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
|
||||
"""
|
||||
assert len(hyps) == B * params.num_channels
|
||||
out_hyps = []
|
||||
for i in range(B):
|
||||
out_hyps.append(hyps[i::B])
|
||||
return out_hyps
|
||||
|
||||
hyps = []
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp))
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": _group_channels(hyps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: _group_channels(hyps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
cut_ids = [cut.id for cut in batch["cuts"]]
|
||||
cuts_batch = batch["cuts"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
for cut_id, hyp_words in zip(cut_ids, hyps):
|
||||
# Reference is a list of supervision texts sorted by start time.
|
||||
ref_words = [
|
||||
s.text.strip()
|
||||
for s in sorted(
|
||||
cuts_batch[cut_id].supervisions, key=lambda s: s.start
|
||||
)
|
||||
]
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(cut_ids)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_surt_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
num_channels=params.num_channels,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LmScorer.add_arguments(parser)
|
||||
AmiAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
), f"Decoding method {params.decoding_method} is not supported."
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_surt_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
ami = AmiAsrDataModule(args)
|
||||
|
||||
# NOTE(@desh2608): we filter segments longer than 120s to avoid OOM errors in decoding.
|
||||
# However, 99.9% of the segments are shorter than 120s, so this should not
|
||||
# substantially affect the results. In future, we will implement an overlapped
|
||||
# inference method to avoid OOM errors.
|
||||
|
||||
test_sets = {}
|
||||
for split in ["dev", "test"]:
|
||||
for type in ["ihm-mix", "sdm", "mdm8-bf"]:
|
||||
test_sets[f"ami-{split}_{type}"] = (
|
||||
ami.ami_cuts(split=split, type=type)
|
||||
.trim_to_supervision_groups(max_pause=0.0)
|
||||
.filter(lambda c: 0.1 < c.duration < 120.0)
|
||||
.to_eager()
|
||||
)
|
||||
|
||||
for split in ["dev", "test"]:
|
||||
for type in ["ihm-mix", "sdm"]:
|
||||
test_sets[f"icsi-{split}_{type}"] = (
|
||||
ami.icsi_cuts(split=split, type=type)
|
||||
.trim_to_supervision_groups(max_pause=0.0)
|
||||
.filter(lambda c: 0.1 < c.duration < 120.0)
|
||||
.to_eager()
|
||||
)
|
||||
|
||||
for test_set, test_cuts in test_sets.items():
|
||||
test_dl = ami.test_dataloaders(test_cuts)
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/ami/SURT/dprnn_zipformer/decoder.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/decoder.py
|
1
egs/ami/SURT/dprnn_zipformer/dprnn.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/dprnn.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/dprnn.py
|
1
egs/ami/SURT/dprnn_zipformer/encoder_interface.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/encoder_interface.py
|
1
egs/ami/SURT/dprnn_zipformer/export.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/export.py
|
1
egs/ami/SURT/dprnn_zipformer/joiner.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/joiner.py
|
1
egs/ami/SURT/dprnn_zipformer/model.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/model.py
|
1
egs/ami/SURT/dprnn_zipformer/optim.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/optim.py
|
1
egs/ami/SURT/dprnn_zipformer/scaling.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/scaling.py
|
1
egs/ami/SURT/dprnn_zipformer/scaling_converter.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/scaling_converter.py
|
1
egs/ami/SURT/dprnn_zipformer/test_model.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/test_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py
|
1420
egs/ami/SURT/dprnn_zipformer/train.py
Executable file
1420
egs/ami/SURT/dprnn_zipformer/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1411
egs/ami/SURT/dprnn_zipformer/train_adapt.py
Executable file
1411
egs/ami/SURT/dprnn_zipformer/train_adapt.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/ami/SURT/dprnn_zipformer/zipformer.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/zipformer.py
|
78
egs/ami/SURT/local/add_source_feats.py
Executable file
78
egs/ami/SURT/local/add_source_feats.py
Executable file
@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file adds source features as temporal arrays to the mixture manifests.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def add_source_feats():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
logging.info("Reading mixed cuts")
|
||||
mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz")
|
||||
mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz")
|
||||
|
||||
logging.info("Reading source cuts")
|
||||
source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz")
|
||||
|
||||
logging.info("Adding source features to the mixed cuts")
|
||||
pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features")
|
||||
with CutSet.open_writer(
|
||||
src_dir / "cuts_train_clean_sources.jsonl.gz"
|
||||
) as cut_writer_clean, CutSet.open_writer(
|
||||
src_dir / "cuts_train_reverb_sources.jsonl.gz"
|
||||
) as cut_writer_reverb, LilcomChunkyWriter(
|
||||
output_dir / "feats_train_clean_sources"
|
||||
) as source_feat_writer:
|
||||
for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb):
|
||||
assert cut_reverb.id == cut_clean.id + "_rvb"
|
||||
source_feats = []
|
||||
source_feat_offsets = []
|
||||
cur_offset = 0
|
||||
for sup in sorted(
|
||||
cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
|
||||
):
|
||||
source_cut = source_cuts[sup.id]
|
||||
source_feats.append(source_cut.load_features())
|
||||
source_feat_offsets.append(cur_offset)
|
||||
cur_offset += source_cut.num_frames
|
||||
cut_clean.source_feats = source_feat_writer.store_array(
|
||||
cut_clean.id, np.concatenate(source_feats, axis=0)
|
||||
)
|
||||
cut_clean.source_feat_offsets = source_feat_offsets
|
||||
cut_writer_clean.write(cut_clean)
|
||||
# Also write the reverb cut
|
||||
cut_reverb.source_feats = cut_clean.source_feats
|
||||
cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets
|
||||
cut_writer_reverb.write(cut_reverb)
|
||||
pbar.update(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
add_source_feats()
|
185
egs/ami/SURT/local/compute_fbank_aimix.py
Executable file
185
egs/ami/SURT/local/compute_fbank_aimix.py
Executable file
@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the synthetically mixed AMI and ICSI
|
||||
train set.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
import torchaudio
|
||||
from lhotse import (
|
||||
AudioSource,
|
||||
LilcomChunkyWriter,
|
||||
Recording,
|
||||
load_manifest,
|
||||
load_manifest_lazy,
|
||||
)
|
||||
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
|
||||
from lhotse.cut import MixedCut, MixTrack, MultiCut
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed, uuid4
|
||||
from tqdm import tqdm
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
torchaudio.set_audio_backend("soundfile")
|
||||
set_ffmpeg_torchaudio_info_enabled(False)
|
||||
|
||||
|
||||
def compute_fbank_aimix():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz")
|
||||
|
||||
# only uses RIRs and noises from REVERB challenge
|
||||
real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
|
||||
lambda r: "RVB2014" in r.id
|
||||
)
|
||||
noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
|
||||
lambda r: "RVB2014" in r.id
|
||||
)
|
||||
|
||||
# Apply perturbation to the training cuts
|
||||
logging.info("Applying perturbation to the training cuts")
|
||||
train_cuts_rvb = train_cuts.map(
|
||||
lambda c: augment(
|
||||
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Extracting fbank features for training cuts")
|
||||
_ = train_cuts.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / "ai-mix_feats_clean",
|
||||
manifest_path=src_dir / "cuts_train_clean.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
_ = train_cuts_rvb.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / "ai-mix_feats_reverb",
|
||||
manifest_path=src_dir / "cuts_train_reverb.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
|
||||
"""
|
||||
Given a mixed cut, this function optionally applies the following augmentations:
|
||||
- Perturbing the SNRs of the tracks (in range [-5, 5] dB)
|
||||
- Reverberation using a randomly selected RIR
|
||||
- Adding noise
|
||||
- Perturbing the loudness (in range [-20, -25] dB)
|
||||
"""
|
||||
out_cut = cut.drop_features()
|
||||
|
||||
# Perturb the SNRs (optional)
|
||||
if perturb_snr:
|
||||
snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
|
||||
for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
|
||||
if i == 0:
|
||||
# Skip the first track since it is the reference
|
||||
continue
|
||||
track.snr = snr
|
||||
|
||||
# Reverberate the cut (optional)
|
||||
if rirs is not None:
|
||||
# Select an RIR at random
|
||||
rir = random.choice(rirs)
|
||||
# Select a channel at random
|
||||
rir_channel = random.choice(list(range(rir.num_channels)))
|
||||
# Reverberate the cut
|
||||
out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])
|
||||
|
||||
# Add noise (optional)
|
||||
if noises is not None:
|
||||
# Select a noise recording at random
|
||||
noise = random.choice(noises).to_cut()
|
||||
if isinstance(noise, MultiCut):
|
||||
noise = noise.to_mono()[0]
|
||||
# Select an SNR at random
|
||||
snr = random.uniform(10, 30)
|
||||
# Repeat the noise to match the duration of the cut
|
||||
noise = repeat_cut(noise, out_cut.duration)
|
||||
out_cut = MixedCut(
|
||||
id=out_cut.id,
|
||||
tracks=[
|
||||
MixTrack(cut=out_cut, type="MixedCut"),
|
||||
MixTrack(cut=noise, type="DataCut", snr=snr),
|
||||
],
|
||||
)
|
||||
|
||||
# Perturb the loudness (optional)
|
||||
if perturb_loudness:
|
||||
target_loudness = random.uniform(-20, -25)
|
||||
out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
|
||||
return out_cut
|
||||
|
||||
|
||||
def repeat_cut(cut, duration):
|
||||
while cut.duration < duration:
|
||||
cut = cut.mix(cut, offset_other_by=cut.duration)
|
||||
return cut.truncate(duration=duration)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
fix_random_seed(42)
|
||||
compute_fbank_aimix()
|
94
egs/ami/SURT/local/compute_fbank_ami.py
Executable file
94
egs/ami/SURT/local/compute_fbank_ami.py
Executable file
@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the AMI dataset.
|
||||
We compute features for full recordings (i.e., without trimming to supervisions).
|
||||
This way we can create arbitrary segmentations later.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_ami():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
manifests = {}
|
||||
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
|
||||
manifests[part] = read_manifests_if_cached(
|
||||
dataset_parts=["train", "dev", "test"],
|
||||
output_dir=src_dir,
|
||||
prefix=f"ami-{part}",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
|
||||
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
|
||||
for split in ["train", "dev", "test"]:
|
||||
logging.info(f"Processing {part} {split}")
|
||||
cuts = CutSet.from_manifests(
|
||||
**manifests[part][split]
|
||||
).compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"ami-{part}_{split}_feats",
|
||||
manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_ami()
|
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the ICSI dataset.
|
||||
We compute features for full recordings (i.e., without trimming to supervisions).
|
||||
This way we can create arbitrary segmentations later.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_icsi():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
manifests = {}
|
||||
for part in ["ihm-mix", "sdm"]:
|
||||
manifests[part] = read_manifests_if_cached(
|
||||
dataset_parts=["train"],
|
||||
output_dir=src_dir,
|
||||
prefix=f"icsi-{part}",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
|
||||
for part in ["ihm-mix", "sdm"]:
|
||||
for split in ["train"]:
|
||||
logging.info(f"Processing {part} {split}")
|
||||
cuts = CutSet.from_manifests(
|
||||
**manifests[part][split]
|
||||
).compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"icsi-{part}_{split}_feats",
|
||||
manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_icsi()
|
101
egs/ami/SURT/local/compute_fbank_ihm.py
Executable file
101
egs/ami/SURT/local/compute_fbank_ihm.py
Executable file
@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the trimmed sub-segments which will be
|
||||
used for simulating the training mixtures.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
import torchaudio
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from tqdm import tqdm
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
torchaudio.set_audio_backend("soundfile")
|
||||
set_ffmpeg_torchaudio_info_enabled(False)
|
||||
|
||||
|
||||
def compute_fbank_ihm():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
manifests = {}
|
||||
for data in ["ami", "icsi"]:
|
||||
manifests[data] = read_manifests_if_cached(
|
||||
dataset_parts=["train"],
|
||||
output_dir=src_dir,
|
||||
types=["recordings", "supervisions"],
|
||||
prefix=f"{data}-ihm",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
|
||||
logging.info("Computing features")
|
||||
for data in ["ami", "icsi"]:
|
||||
cs = CutSet.from_manifests(**manifests[data]["train"])
|
||||
cs = cs.trim_to_supervisions(keep_overlapping=False)
|
||||
cs = cs.normalize_loudness(target=-23.0, affix_id=False)
|
||||
cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1)
|
||||
_ = cs.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"{data}-ihm_train_feats",
|
||||
manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_ihm()
|
146
egs/ami/SURT/local/prepare_ami_train_cuts.py
Executable file
146
egs/ami/SURT/local/prepare_ami_train_cuts.py
Executable file
@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file creates AMI train segments.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import LilcomChunkyWriter, load_manifest_lazy
|
||||
from lhotse.cut import Cut, CutSet
|
||||
from lhotse.utils import EPSILON, add_durations
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def cut_into_windows(cuts: CutSet, duration: float):
|
||||
"""
|
||||
This function takes a CutSet and cuts each cut into windows of roughly
|
||||
`duration` seconds. By roughly, we mean that we try to adjust for the last supervision
|
||||
that exceeds the duration, or is shorter than the duration.
|
||||
"""
|
||||
res = []
|
||||
with tqdm() as pbar:
|
||||
for cut in cuts:
|
||||
pbar.update(1)
|
||||
sups = cut.index_supervisions()[cut.id]
|
||||
sr = cut.sampling_rate
|
||||
start = 0.0
|
||||
end = duration
|
||||
num_tries = 0
|
||||
while start < cut.duration and num_tries < 2:
|
||||
# Find the supervision that are cut by the window endpoint
|
||||
hitlist = [iv for iv in sups.at(end) if iv.begin < end]
|
||||
# If there are no supervisions, we are done
|
||||
if not hitlist:
|
||||
res.append(
|
||||
cut.truncate(
|
||||
offset=start,
|
||||
duration=add_durations(end, -start, sampling_rate=sr),
|
||||
keep_excessive_supervisions=False,
|
||||
)
|
||||
)
|
||||
# Update the start and end for the next window
|
||||
start = end
|
||||
end = add_durations(end, duration, sampling_rate=sr)
|
||||
else:
|
||||
# find ratio of durations cut by the window endpoint
|
||||
ratios = [
|
||||
add_durations(end, -iv.end, sampling_rate=sr) / iv.length()
|
||||
for iv in hitlist
|
||||
]
|
||||
# we retain the supervisions that have >50% of their duration
|
||||
# in the window, and discard the others
|
||||
retained = []
|
||||
discarded = []
|
||||
for iv, ratio in zip(hitlist, ratios):
|
||||
if ratio > 0.5:
|
||||
retained.append(iv)
|
||||
else:
|
||||
discarded.append(iv)
|
||||
cur_end = max(iv.end for iv in retained) if retained else end
|
||||
res.append(
|
||||
cut.truncate(
|
||||
offset=start,
|
||||
duration=add_durations(cur_end, -start, sampling_rate=sr),
|
||||
keep_excessive_supervisions=False,
|
||||
)
|
||||
)
|
||||
# For the next window, we start at the earliest discarded supervision
|
||||
next_start = min(iv.begin for iv in discarded) if discarded else end
|
||||
next_end = add_durations(next_start, duration, sampling_rate=sr)
|
||||
# It may happen that next_start is the same as start, in which case
|
||||
# we will advance the window anyway
|
||||
if next_start == start:
|
||||
logging.warning(
|
||||
f"Next start is the same as start: {next_start} == {start} for cut {cut.id}"
|
||||
)
|
||||
start = end + EPSILON
|
||||
end = add_durations(start, duration, sampling_rate=sr)
|
||||
num_tries += 1
|
||||
else:
|
||||
start = next_start
|
||||
end = next_end
|
||||
return CutSet.from_cuts(res)


def prepare_train_cuts():
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_ami-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )
    train_cuts_mdm = load_manifest_lazy(
        src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_mdm8-bf"))

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We will cut them into
    # windows of roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # Show statistics
    train_all_30.describe(full=True)

    # Save the cuts
    logging.info("Saving the cuts")
    train_all_30.to_file(src_dir / "cuts_train_ami.jsonl.gz")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_train_cuts()
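The windowing helper above can be exercised on its own. Below is a minimal, hypothetical usage sketch (the manifest path is an assumption; any `CutSet` with supervisions works):

```python
# Hypothetical usage of cut_into_windows; the manifest path is an assumption.
from lhotse import load_manifest_lazy

from prepare_ami_train_cuts import cut_into_windows

cuts = load_manifest_lazy("data/manifests/cuts_ami-sdm_train.jsonl.gz")
# Split long cuts into ~30 s windows; window boundaries are nudged so that
# supervisions with more than half of their duration inside a window stay whole.
windowed = cut_into_windows(cuts, duration=30.0)
print(f"{len(windowed)} cuts after windowing")
```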
67
egs/ami/SURT/local/prepare_icsi_train_cuts.py
Executable file
67
egs/ami/SURT/local/prepare_icsi_train_cuts.py
Executable file
@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file creates ICSI train segments.
"""
import logging
from pathlib import Path

from lhotse import load_manifest_lazy
from prepare_ami_train_cuts import cut_into_windows


def prepare_train_cuts():
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We will cut them into
    # windows of roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # Show statistics
    train_all_30.describe(full=True)

    # Save the cuts
    logging.info("Saving the cuts")
    train_all_30.to_file(src_dir / "cuts_train_icsi.jsonl.gz")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_train_cuts()
1
egs/ami/SURT/local/prepare_lang_bpe.py
Symbolic link
1
egs/ami/SURT/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py
1
egs/ami/SURT/local/train_bpe_model.py
Symbolic link
1
egs/ami/SURT/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py
195
egs/ami/SURT/prepare.sh
Executable file
195
egs/ami/SURT/prepare.sh
Executable file
@ -0,0 +1,195 @@
#!/usr/bin/env bash

set -eou pipefail

stage=-1
stop_stage=100

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/ami
#      You can find audio and transcripts for AMI in this path.
#
#  - $dl_dir/icsi
#      You can find audio and transcripts for ICSI in this path.
#
#  - $dl_dir/rirs_noises
#      This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
#
dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/amicorpus,
  # you can create a symlink
  #
  #   ln -sfv /path/to/amicorpus $dl_dir/amicorpus
  #
  if [ ! -d $dl_dir/amicorpus ]; then
    for mic in ihm ihm-mix sdm mdm8-bf; do
      lhotse download ami --mic $mic $dl_dir/amicorpus
    done
  fi

  # If you have pre-downloaded it to /path/to/icsi,
  # you can create a symlink
  #
  #   ln -sfv /path/to/icsi $dl_dir/icsi
  #
  if [ ! -d $dl_dir/icsi ]; then
    lhotse download icsi $dl_dir/icsi
  fi

  # If you have pre-downloaded it to /path/to/rirs_noises,
  # you can create a symlink
  #
  #   ln -sfv /path/to/rirs_noises $dl_dir/
  #
  if [ ! -d $dl_dir/rirs_noises ]; then
    lhotse download rirs_noises $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare AMI manifests"
  # We assume that you have downloaded the AMI corpus
  # to $dl_dir/amicorpus. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  for mic in ihm ihm-mix sdm mdm8-bf; do
    log "Preparing AMI manifest for $mic"
    lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
  done
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare ICSI manifests"
  # We assume that you have downloaded the ICSI corpus
  # to $dl_dir/icsi. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  log "Preparing ICSI manifest"
  for mic in ihm ihm-mix sdm; do
    lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
  done
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare RIRs"
  # We assume that you have downloaded the RIRS_NOISES corpus
  # to $dl_dir/rirs_noises
  lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Extract features for AMI and ICSI recordings"
  python local/compute_fbank_ami.py
  python local/compute_fbank_icsi.py
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Create sources for simulating mixtures"
  # In the following script, we speed-perturb the IHM recordings and extract features.
  python local/compute_fbank_ihm.py
  lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
    data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
    lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
    lhotse filter 'duration<=12.0' - - |\
    shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Create training mixtures"
  lhotse workflows simulate-meetings \
    --method conversational \
    --same-spk-pause 0.5 \
    --diff-spk-pause 0.5 \
    --diff-spk-overlap 1.0 \
    --prob-diff-spk-overlap 0.8 \
    --num-meetings 200000 \
    --num-speakers-per-meeting 2,3 \
    --max-duration-per-speaker 15.0 \
    --max-utterances-per-speaker 3 \
    --seed 1234 \
    --num-jobs 2 \
    data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
    data/manifests/ai-mix_cuts_clean.jsonl.gz

  python local/compute_fbank_aimix.py

  # Add source features to the manifest (will be used for masking loss)
  # This may take ~2 hours.
  python local/add_source_feats.py

  # Combine clean and reverb
  cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
fi

if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Create training mixtures from real sessions"
  python local/prepare_ami_train_cuts.py
  python local/prepare_icsi_train_cuts.py

  # Combine AMI and ICSI
  cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
  mkdir -p data/lm
  cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    <(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    > data/lm/transcript_words.txt
fi

if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)"

  lang_dir=data/lang_bpe_${vocab_size}
  mkdir -p $lang_dir

  # Add special words to words.txt
  echo "<eps> 0" > $lang_dir/words.txt
  echo "!SIL 1" >> $lang_dir/words.txt
  echo "<UNK> 2" >> $lang_dir/words.txt

  # Add regular words to words.txt
  cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt

  # Add remaining special word symbols expected by LM scripts.
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "<s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "</s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "#0 ${num_words}" >> $lang_dir/words.txt

  ./local/train_bpe_model.py \
    --lang-dir $lang_dir \
    --vocab-size $vocab_size \
    --transcript data/lm/transcript_words.txt

  if [ ! -f $lang_dir/L_disambig.pt ]; then
    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
  fi
fi
1
egs/ami/SURT/shared
Symbolic link
1
egs/ami/SURT/shared
Symbolic link
@ -0,0 +1 @@
../../../icefall/shared
@ -45,7 +45,7 @@ def get_args():

def normalize_text(utt: str) -> str:
    utt = re.sub(r"[{0}]+".format("-"), " ", utt)
    return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
    return re.sub(r"[^a-zA-Z\s']", "", utt).upper()


def preprocess_commonvoice(
249
egs/libricss/SURT/README.md
Normal file
249
egs/libricss/SURT/README.md
Normal file
@ -0,0 +1,249 @@
# Introduction

This is a multi-talker ASR recipe for the LibriCSS dataset. We train a Streaming
Unmixing and Recognition Transducer (SURT) model for the task. In this README,
we will describe the task, the model, and the training process. We will also
provide links to pre-trained models and training logs.

## Task

LibriCSS is a multi-talker meeting corpus formed by mixing together LibriSpeech utterances
and replaying them in a real meeting room. It consists of 10 one-hour sessions of audio,
each recorded with a 7-channel microphone array. The sessions are recorded at a sampling
rate of 16 kHz. For more information, refer to the paper:
Z. Chen et al., "Continuous speech separation: dataset and analysis,"
ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
Barcelona, Spain, 2020.

In this recipe, we perform the "continuous, streaming, multi-talker ASR" task on LibriCSS.

* By "continuous", we mean that the model should be able to transcribe unsegmented audio
without the need for an external VAD.
* By "streaming", we mean that the model has limited right context. We use a right-context
of at most 32 frames (320 ms).
* By "multi-talker", we mean that the model should be able to transcribe overlapping speech
from multiple speakers.

For now, we do not care about speaker attribution, i.e., the transcription is speaker
agnostic. The evaluation depends on the particular model type. In this case, we use
the optimal reference combination WER (ORC-WER) metric, as implemented in the
[meeteval](https://github.com/fgnt/meeteval) toolkit.
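
To make the metric concrete, here is a small, self-contained sketch (not meeteval's implementation) that computes ORC-WER by brute force: it tries every assignment of reference utterances to output channels and keeps the one with the lowest total edit distance. It is exponential in the number of utterances and only meant to illustrate the definition:

```python
# Toy ORC-WER: brute-force over channel assignments (illustrative only;
# use meeteval for the real, efficient implementation).
from itertools import product
from typing import List


def edit_distance(a: List[str], b: List[str]) -> int:
    """Word-level Levenshtein distance with a single rolling row."""
    dp = list(range(len(b) + 1))
    for i, wa in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, wb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (wa != wb))
    return dp[-1]


def orc_wer(ref_utts: List[List[str]], hyp_channels: List[List[str]]) -> float:
    """Minimum WER over all utterance-to-channel assignments."""
    num_ref_words = sum(len(u) for u in ref_utts)
    best = None
    for assignment in product(range(len(hyp_channels)), repeat=len(ref_utts)):
        # Concatenate the references assigned to each channel, in order.
        per_channel = [[] for _ in hyp_channels]
        for utt, ch in zip(ref_utts, assignment):
            per_channel[ch].extend(utt)
        errs = sum(edit_distance(r, h) for r, h in zip(per_channel, hyp_channels))
        best = errs if best is None else min(best, errs)
    return best / num_ref_words


refs = [["hello", "world"], ["good", "morning"]]
hyps = [["hello", "world"], ["good", "mourning"]]
print(orc_wer(refs, hyps))  # 1 substitution / 4 reference words = 0.25
```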

## Model

We use the Streaming Unmixing and Recognition Transducer (SURT) model for this task.
The model is based on the papers:

- Lu, Liang et al. “Streaming End-to-End Multi-Talker Speech Recognition.” IEEE Signal Processing Letters 28 (2020): 803-807.
- Raj, Desh et al. “Continuous Streaming Multi-Talker ASR with Dual-Path Transducers.” ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2021): 7317-7321.

The model is a combination of a speech separation model and a speech recognition model,
but trained end-to-end with a single loss function. The overall architecture is shown
in the figure below. Note that this architecture is slightly different from the one
in the above papers. A detailed description of the model can be found in the following
paper: [SURT 2.0: Advances in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).

<p align="center">

 <img src="surt.png">
 Streaming Unmixing and Recognition Transducer

</p>
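
As a rough sketch of the data flow (module names below are placeholders, not the actual classes in this recipe, which uses DPRNN and Zipformer): the masking network produces one mask per output channel, each mask is applied to the input features, and a shared recognition network transcribes each masked stream:

```python
# Minimal sketch of the SURT data flow, assuming hypothetical modules;
# the real recipe uses a DPRNN mask network and a Zipformer transducer.
import torch
import torch.nn as nn


class ToySURT(nn.Module):
    def __init__(self, feat_dim: int = 80, num_channels: int = 2):
        super().__init__()
        # Mask estimator: one sigmoid mask per output channel.
        self.mask_net = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, feat_dim * num_channels),
            nn.Sigmoid(),
        )
        # Shared recognition encoder applied to each masked stream.
        self.encoder = nn.GRU(feat_dim, 256, batch_first=True)
        self.num_channels = num_channels

    def forward(self, x: torch.Tensor):
        # x: (N, T, F) fbank features of the mixture.
        N, T, F = x.shape
        masks = self.mask_net(x).view(N, T, self.num_channels, F)
        outs = []
        for c in range(self.num_channels):
            masked = x * masks[:, :, c]  # unmix channel c
            enc, _ = self.encoder(masked)  # recognize channel c
            outs.append(enc)
        # Each element would feed a transducer decoder/joiner in the real model.
        return outs


surt = ToySURT()
print([o.shape for o in surt(torch.randn(2, 100, 80))])
```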

In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
and a Zipformer-based recognition network. But other combinations are possible as well.

## Training objective

We train the model using the pruned transducer loss, similar to other ASR recipes in
icefall. However, an important consideration is how to assign references to the output
channels (2 in this case). For this, we use the heuristic error assignment training (HEAT)
strategy, which assigns references to the first available channel based on their start
times. An illustrative example is shown in the figure below:

<p align="center">

 <img src="heat.png">
 Illustration of HEAT-based reference assignment.

</p>
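
A minimal sketch of the HEAT heuristic (an illustration, not the recipe's exact implementation): sort the reference utterances by start time and assign each one to the first channel whose last assigned utterance has already ended:

```python
# HEAT-style reference assignment sketch (illustrative; names are assumptions).
from typing import List, Tuple

Utterance = Tuple[float, float, str]  # (start, end, text)


def heat_assign(utts: List[Utterance], num_channels: int = 2) -> List[List[Utterance]]:
    """Assign each utterance to the first channel that is free at its start time."""
    channels: List[List[Utterance]] = [[] for _ in range(num_channels)]
    ends = [0.0] * num_channels  # end time of the last utterance per channel
    for utt in sorted(utts, key=lambda u: u[0]):
        start, end, _ = utt
        for c in range(num_channels):
            if ends[c] <= start:  # channel c is free
                channels[c].append(utt)
                ends[c] = end
                break
        else:
            raise ValueError(f"More than {num_channels} overlapping utterances")
    return channels


mix = [(0.0, 3.0, "A1"), (2.0, 5.0, "B1"), (4.0, 6.0, "A2")]
print(heat_assign(mix))
# A1 -> channel 0; B1 -> channel 1 (overlaps A1); A2 -> channel 0 (free at 4.0)
```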

## Description of the recipe

### Pre-requisites

The recipes in this directory need the following packages to be installed:

- [meeteval](https://github.com/fgnt/meeteval)
- [einops](https://github.com/arogozhnikov/einops)

Additionally, we initialize the "recognition" transducer with a pre-trained model,
trained on LibriSpeech. For this, please run the following from within `egs/librispeech/ASR`:

```bash
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1,2,3"
python pruned_transducer_stateless7_streaming/train.py \
  --use-fp16 True \
  --exp-dir pruned_transducer_stateless7_streaming/exp \
  --world-size 4 \
  --max-duration 800 \
  --num-epochs 10 \
  --keep-last-k 1 \
  --manifest-dir data/manifests \
  --enable-musan true \
  --master-port 54321 \
  --bpe-model data/lang_bpe_500/bpe.model \
  --num-encoder-layers 2,2,2,2,2 \
  --feedforward-dims 768,768,768,768,768 \
  --nhead 8,8,8,8,8 \
  --encoder-dims 256,256,256,256,256 \
  --attention-dims 192,192,192,192,192 \
  --encoder-unmasked-dims 192,192,192,192,192 \
  --zipformer-downsampling-factors 1,2,4,8,2 \
  --cnn-module-kernels 31,31,31,31,31 \
  --decoder-dim 512 \
  --joiner-dim 512
```

The above is for SURT-base (~26M). For SURT-large (~38M), use `--num-encoder-layers 2,4,3,2,4`.

Once the above model is trained for 10 epochs, copy it to `egs/libricss/SURT/exp`:

```bash
cp -r pruned_transducer_stateless7_streaming/exp/epoch-10.pt exp/zipformer_base.pt
```

**NOTE:** We also provide this pre-trained checkpoint (see the section below), so you can skip
the above step if you want.

### Training

To train the model, run the following from within `egs/libricss/SURT`:

```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"

python dprnn_zipformer/train.py \
  --use-fp16 True \
  --exp-dir dprnn_zipformer/exp/surt_base \
  --world-size 4 \
  --max-duration 500 \
  --max-duration-valid 250 \
  --max-cuts 200 \
  --num-buckets 50 \
  --num-epochs 30 \
  --enable-spec-aug True \
  --enable-musan False \
  --ctc-loss-scale 0.2 \
  --heat-loss-scale 0.2 \
  --base-lr 0.004 \
  --model-init-ckpt exp/zipformer_base.pt \
  --chunk-width-randomization True \
  --num-mask-encoder-layers 4 \
  --num-encoder-layers 2,2,2,2,2
```

The above is for SURT-base (~26M). For SURT-large (~38M), use:

```bash
  --num-mask-encoder-layers 6 \
  --num-encoder-layers 2,4,3,2,4 \
  --model-init-ckpt exp/zipformer_large.pt \
```

**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.

### Adaptation

The training step above only trains on simulated mixtures. For best results, we also
adapt the final model on the LibriCSS dev set. For this, run the following from within
`egs/libricss/SURT`:

```bash
export CUDA_VISIBLE_DEVICES="0"

python dprnn_zipformer/train_adapt.py \
  --use-fp16 True \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --world-size 1 \
  --max-duration 500 \
  --max-duration-valid 250 \
  --max-cuts 200 \
  --num-buckets 50 \
  --num-epochs 8 \
  --lr-epochs 2 \
  --enable-spec-aug True \
  --enable-musan False \
  --ctc-loss-scale 0.2 \
  --base-lr 0.0004 \
  --model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
  --chunk-width-randomization True \
  --num-mask-encoder-layers 4 \
  --num-encoder-layers 2,2,2,2,2
```

For SURT-large, use the following config:

```bash
  --num-mask-encoder-layers 6 \
  --num-encoder-layers 2,4,3,2,4 \
  --model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
  --num-epochs 15 \
  --lr-epochs 4 \
```


### Decoding

To decode the model, run the following from within `egs/libricss/SURT`:

#### Greedy search

```bash
export CUDA_VISIBLE_DEVICES="0"

python dprnn_zipformer/decode.py \
  --epoch 8 --avg 1 --use-averaged-model False \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --max-duration 250 \
  --decoding-method greedy_search
```

#### Beam search

```bash
python dprnn_zipformer/decode.py \
  --epoch 8 --avg 1 --use-averaged-model False \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --max-duration 250 \
  --decoding-method modified_beam_search \
  --beam-size 4
```

## Results (using beam search)

### IHM-Mix

| Model | # params (M) | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|------------|:-------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| dprnn_zipformer (base) | 26.7 | 5.1 | 4.2 | 13.7 | 18.7 | 20.5 | 20.6 | 13.8 |
| dprnn_zipformer (large) | 37.9 | 4.6 | 3.8 | 12.7 | 14.3 | 16.7 | 21.2 | 12.2 |

### SDM

| Model | # params (M) | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|------------|:-------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| dprnn_zipformer (base) | 26.7 | 6.8 | 7.2 | 21.4 | 24.5 | 28.6 | 31.2 | 20.0 |
| dprnn_zipformer (large) | 37.9 | 6.4 | 6.9 | 17.9 | 19.7 | 25.2 | 25.5 | 16.9 |

## Pre-trained models and logs

* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer>

* Training logs:
  - surt_base: <https://tensorboard.dev/experiment/YLGJTkBETb2aqDQ61jbxvQ/>
  - surt_base_adapt: <https://tensorboard.dev/experiment/pjXMFVL9RMej85rMHyd0EQ/>
  - surt_large: <https://tensorboard.dev/experiment/82HvYqfrSOKZ4w8Jod2QMw/>
  - surt_large_adapt: <https://tensorboard.dev/experiment/5oIdEgRqS9Wb6yVuxaExEw/>
372
egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
372
egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,372 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2023 Johns Hopkins University (Author: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutMix,
    DynamicBucketingSampler,
    K2SurtDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class LibriCssAsrDataModule:
    """
    DataModule for k2 ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/manifests"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--max-duration-valid",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--max-cuts",
            type=int,
            default=100,
            help="Maximum number of cuts in a single batch. You can "
            "reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help=(
                "When enabled, use on-the-fly cut mixing and feature "
                "extraction. Will drop existing precomputed feature manifests "
                "if available."
            ),
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with the training dataset.",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
        return_sources: bool = True,
        strict: bool = True,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []
        if self.args.enable_musan:
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
            transforms.append(
                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        train = K2SurtDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
            return_sources=return_sources,
            strict=strict,
        )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                quadratic_duration=30.0,
                max_cuts=self.args.max_cuts,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                max_cuts=self.args.max_cuts,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        transforms = []

        logging.info("About to create dev dataset")
        validate = K2SurtDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            cut_transforms=transforms,
            return_cuts=self.args.return_cuts,
            return_sources=False,
            strict=False,
        )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration_valid,
            max_cuts=self.args.max_cuts,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = K2SurtDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            return_cuts=self.args.return_cuts,
            return_sources=False,
            strict=False,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration_valid,
            max_cuts=self.args.max_cuts,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def lsmix_cuts(
        self,
        rvb_affix: str = "clean",
        type_affix: str = "full",
        sources: bool = True,
    ) -> CutSet:
        logging.info("About to get train cuts")
        source_affix = "_sources" if sources else ""
        cs = load_manifest_lazy(
            self.args.manifest_dir
            / f"cuts_train_{rvb_affix}_{type_affix}{source_affix}.jsonl.gz"
        )
        cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
        return cs

    @lru_cache()
    def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
        logging.info(f"About to get LibriCSS {split} {type} cuts")
        cs = load_manifest_lazy(
            self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
        )
        return cs
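For context, this is roughly how the training scripts wire up the data module. A hedged sketch follows; the argument values, the cut affixes, and the `batch["inputs"]` key are assumptions:

```python
# Hypothetical driver for LibriCssAsrDataModule (mirrors how train.py uses it).
import argparse

from asr_datamodule import LibriCssAsrDataModule

parser = argparse.ArgumentParser()
LibriCssAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "500", "--max-cuts", "200"])

datamodule = LibriCssAsrDataModule(args)
# Affixes select a manifest such as cuts_train_comb_full_sources.jsonl.gz
# (the exact file name is an assumption and depends on prepare.sh outputs).
train_cuts = datamodule.lsmix_cuts(rvb_affix="comb", type_affix="full", sources=True)
train_dl = datamodule.train_dataloaders(train_cuts)

for batch in train_dl:
    # "inputs" key assumed, as in lhotse's other K2 dataset classes.
    print(batch["inputs"].shape)  # (N, T, num_features)
    break
```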
730
egs/libricss/SURT/dprnn_zipformer/beam_search.py
Normal file
730
egs/libricss/SURT/dprnn_zipformer/beam_search.py
Normal file
@ -0,0 +1,730 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
#                                       Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union

import k2
import torch
from model import SURT

from icefall import NgramLmStateCost
from icefall.utils import DecodingResults


def greedy_search(
    model: SURT,
    encoder_out: torch.Tensor,
    max_sym_per_frame: int,
    return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
    """Greedy search for a single utterance.
    Args:
      model:
        An instance of `SURT`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      max_sym_per_frame:
        Maximum number of symbols per frame. If it is set to 0, the WER
        would be 100%.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    unk_id = getattr(model, "unk_id", blank_id)

    device = next(model.parameters()).device

    decoder_input = torch.tensor(
        [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_out = model.joiner.decoder_proj(decoder_out)

    encoder_out = model.joiner.encoder_proj(encoder_out)

    T = encoder_out.size(1)
    t = 0
    hyp = [blank_id] * context_size

    # timestamp[i] is the frame index after subsampling
    # on which hyp[i] is decoded
    timestamp = []

    # Maximum symbols per utterance.
    max_sym_per_utt = 1000

    # symbols per frame
    sym_per_frame = 0

    # symbols per utterance decoded so far
    sym_per_utt = 0

    while t < T and sym_per_utt < max_sym_per_utt:
        if sym_per_frame >= max_sym_per_frame:
            sym_per_frame = 0
            t += 1
            continue

        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        # fmt: on
        logits = model.joiner(
            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
        )
        # logits is (1, 1, 1, vocab_size)

        y = logits.argmax().item()
        if y not in (blank_id, unk_id):
            hyp.append(y)
            timestamp.append(t)
            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
                1, context_size
            )

            decoder_out = model.decoder(decoder_input, need_pad=False)
            decoder_out = model.joiner.decoder_proj(decoder_out)

            sym_per_utt += 1
            sym_per_frame += 1
        else:
            sym_per_frame = 0
            t += 1
    hyp = hyp[context_size:]  # remove blanks

    if not return_timestamps:
        return hyp
    else:
        return DecodingResults(
            hyps=[hyp],
            timestamps=[timestamp],
        )


def greedy_search_batch(
    model: SURT,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The SURT model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = next(model.parameters()).device

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]

    # timestamp[n][i] is the frame index after subsampling
    # on which hyp[n][i] is decoded
    timestamps = [[] for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_out = model.joiner.decoder_proj(decoder_out)
    # decoder_out: (N, 1, decoder_out_dim)

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    for (t, batch_size) in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
        offset = end

        decoder_out = decoder_out[:batch_size]

        logits = model.joiner(
            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
        )
        # logits' shape: (batch_size, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v not in (blank_id, unk_id):
                hyps[i].append(v)
                timestamps[i].append(t)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(decoder_input, need_pad=False)
            decoder_out = model.joiner.decoder_proj(decoder_out)

    sorted_ans = [h[context_size:] for h in hyps]
    ans = []
    ans_timestamps = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
        ans_timestamps.append(timestamps[unsorted_indices[i]])

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
        )


def modified_beam_search(
    model: SURT,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    Args:
      model:
        The SURT model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding.
      beam:
        Number of active paths during the beam search.
      temperature:
        Softmax temperature.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(
                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
                )
                B[i].add(new_hyp)

    B = B + finalized_B
    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
    sorted_timestamps = [h.timestamp for h in best_hyps]
    ans = []
    ans_timestamps = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
        )


def beam_search(
    model: SURT,
    encoder_out: torch.Tensor,
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf

    espnet/nets/beam_search_transducer.py#L247 is used as a reference.

    Args:
      model:
        An instance of `SURT`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
      temperature:
        Softmax temperature.
      return_timestamps:
        Whether to return timestamps.

    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size

    device = next(model.parameters()).device

    decoder_input = torch.tensor(
        [blank_id] * context_size,
        device=device,
        dtype=torch.int64,
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_out = model.joiner.decoder_proj(decoder_out)

    encoder_out = model.joiner.encoder_proj(encoder_out)

    T = encoder_out.size(1)
    t = 0

    B = HypothesisList()
    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))

    max_sym_per_utt = 20000

    sym_per_utt = 0

    decoder_cache: Dict[str, torch.Tensor] = {}

    while t < T and sym_per_utt < max_sym_per_utt:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        # fmt: on
        A = B
        B = HypothesisList()

        joint_cache: Dict[str, torch.Tensor] = {}

        # TODO(fangjun): Implement prefix search to update the `log_prob`
        # of hypotheses in A

        while True:
            y_star = A.get_most_probable()
            A.remove(y_star)

            cached_key = y_star.key

            if cached_key not in decoder_cache:
                decoder_input = torch.tensor(
                    [y_star.ys[-context_size:]],
                    device=device,
                    dtype=torch.int64,
                ).reshape(1, context_size)

                decoder_out = model.decoder(decoder_input, need_pad=False)
                decoder_out = model.joiner.decoder_proj(decoder_out)
                decoder_cache[cached_key] = decoder_out
            else:
                decoder_out = decoder_cache[cached_key]

            cached_key += f"-t-{t}"
            if cached_key not in joint_cache:
                logits = model.joiner(
                    current_encoder_out,
                    decoder_out.unsqueeze(1),
                    project_input=False,
                )

                # TODO(fangjun): Scale the blank posterior
                log_prob = (logits / temperature).log_softmax(dim=-1)
                # log_prob is (1, 1, 1, vocab_size)
                log_prob = log_prob.squeeze()
                # Now log_prob is (vocab_size,)
                joint_cache[cached_key] = log_prob
            else:
                log_prob = joint_cache[cached_key]

            # First, process the blank symbol
            skip_log_prob = log_prob[blank_id]
            new_y_star_log_prob = y_star.log_prob + skip_log_prob

            # ys[:] returns a copy of ys
            B.add(
                Hypothesis(
                    ys=y_star.ys[:],
                    log_prob=new_y_star_log_prob,
                    timestamp=y_star.timestamp[:],
                )
            )

            # Second, process other non-blank labels
            values, indices = log_prob.topk(beam + 1)
            for i, v in zip(indices.tolist(), values.tolist()):
                if i in (blank_id, unk_id):
                    continue
                new_ys = y_star.ys + [i]
                new_log_prob = y_star.log_prob + v
                new_timestamp = y_star.timestamp + [t]
                A.add(
                    Hypothesis(
                        ys=new_ys,
                        log_prob=new_log_prob,
                        timestamp=new_timestamp,
                    )
                )

            # Check whether B contains more than "beam" elements more probable
            # than the most probable in A
            A_most_probable = A.get_most_probable()

            kept_B = B.filter(A_most_probable.log_prob)

            if len(kept_B) >= beam:
                B = kept_B.topk(beam)
                break

        t += 1

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

    if not return_timestamps:
        return ys
    else:
        return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])


@dataclass
class Hypothesis:
    # The predicted tokens so far.
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]

    # The log prob of ys.
    # It contains only one entry.
    log_prob: torch.Tensor

    # timestamp[i] is the frame index after subsampling
    # on which ys[i] is decoded
    timestamp: List[int] = field(default_factory=list)

    # the lm score for next token given the current ys
    lm_score: Optional[torch.Tensor] = None

    # the RNNLM states (h and c in LSTM)
    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    # N-gram LM state
    state_cost: Optional[NgramLmStateCost] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
        return "_".join(map(str, self.ys))


class HypothesisList(object):
    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
        """
        Args:
          data:
            A dict of Hypotheses. Its key is its `value.key`.
        """
        if data is None:
            self._data = {}
        else:
            self._data = data

    @property
    def data(self) -> Dict[str, Hypothesis]:
        return self._data

    def add(self, hyp: Hypothesis) -> None:
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated using
        `log-sum-exp` with the existing one.

        Args:
          hyp:
            The hypothesis to be added.
        """
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]  # shallow copy
            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
        else:
            self._data[key] = hyp

    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
        """Get the most probable hypothesis, i.e., the one with
        the largest `log_prob`.

        Args:
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
        Returns:
          Return the hypothesis that has the largest `log_prob`.
        """
        if length_norm:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
        else:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob)

    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.

        Caution:
          `self` is modified **in-place**.

        Args:
          hyp:
            The hypothesis to be removed from `self`.
            Note: It must be contained in `self`. Otherwise,
            an exception is raised.
        """
        key = hyp.key
        assert key in self, f"{key} does not exist"
        del self._data[key]

    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
        """Remove all Hypotheses whose log_prob is less than threshold.

        Caution:
          `self` is not modified. Instead, a new HypothesisList is returned.

        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
          with `log_prob` being greater than the given `threshold`.
        """
        ans = HypothesisList()
        for _, hyp in self._data.items():
            if hyp.log_prob > threshold:
                ans.add(hyp)  # shallow copy
        return ans

    def topk(self, k: int) -> "HypothesisList":
        """Return the top-k hypotheses."""
        hyps = list(self._data.items())

        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]

        ans = HypothesisList(dict(hyps))
        return ans

    def __contains__(self, key: str):
        return key in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self) -> int:
        return len(self._data)

    def __str__(self) -> str:
        s = []
        for key in self:
            s.append(key)
        return ", ".join(s)


def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
    """Return a ragged shape with axes [utt][num_hyps].

    Args:
      hyps:
        len(hyps) == batch_size. It contains the current hypothesis for
        each utterance in the batch.
    Returns:
      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
      the shape is on CPU.
    """
    num_hyps = [len(h) for h in hyps]

    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
    # to get exclusive sum later.
    num_hyps.insert(0, 0)

    num_hyps = torch.tensor(num_hyps)
    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
    ans = k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
    )
    return ans
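To illustrate the recombination step in `HypothesisList.add` (duplicate token sequences have their probabilities merged with log-sum-exp rather than being kept as separate beam entries), here is a small, self-contained sketch:

```python
# Recombination sketch: adding the same token sequence twice merges
# probabilities with log-sum-exp, as HypothesisList.add does above.
import torch

# Two beam paths that emitted the same tokens, e.g. [5, 9]:
log_p1 = torch.tensor([-1.0])
log_p2 = torch.tensor([-2.0])

merged = torch.logaddexp(log_p1, log_p2)
print(merged)  # log(exp(-1) + exp(-2)) ~= -0.6867

# Keeping them separate would waste two beam slots on one hypothesis;
# merging frees a slot and gives the sequence its total probability.
```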