Merge remote-tracking branch 'upstream/master' into gigaspeech_rnnt

This commit is contained in:
Guanbo Wang 2022-05-03 18:21:27 -04:00
commit b747e6a43d
79 changed files with 7791 additions and 762 deletions

View File

@ -10,6 +10,8 @@ per-file-ignores =
egs/gigaspeech/ASR/*/conformer.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
egs/gigaspeech/ASR/pruned_transducer_stateless2/*.py: E501,
egs/librispeech/ASR/*/optim.py: E501,
egs/librispeech/ASR/*/scaling.py: E501,
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605

View File

@ -0,0 +1,47 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./pruned_transducer_stateless/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done

View File

@ -0,0 +1,51 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-epoch-38-avg-10.pt pretrained.pt
popd
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless2/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./pruned_transducer_stateless2/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done

View File

@ -0,0 +1,51 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-epoch-25-avg-6.pt pretrained.pt
popd
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless3/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./pruned_transducer_stateless3/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done

View File

@ -0,0 +1,47 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./transducer_stateless2/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./transducer_stateless2/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done

View File

@ -0,0 +1,46 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
git lfs install
git clone $repo
log "Downloading pre-trained model from $repo_url"
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.flac
ls -lh $repo/test_wavs/*.flac
log "CTC decoding"
./conformer_ctc/pretrained.py \
--method ctc-decoding \
--num-classes 500 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac
log "HLG decoding"
./conformer_ctc/pretrained.py \
--method 1best \
--num-classes 500 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac

View File

@ -0,0 +1,47 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./transducer_stateless_multi_datasets/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done

View File

@ -0,0 +1,47 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./transducer_stateless_multi_datasets/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done

View File

@ -0,0 +1,47 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/aishell/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./transducer_stateless_modified-2/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$repo/test_wavs/BAC009S0764W0123.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./transducer_stateless_modified-2/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$repo/test_wavs/BAC009S0764W0123.wav
done

View File

@ -0,0 +1,47 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/aishell/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./transducer_stateless_modified/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$repo/test_wavs/BAC009S0764W0123.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./transducer_stateless_modified/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$repo/test_wavs/BAC009S0764W0123.wav
done

View File

@ -0,0 +1,60 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./transducer_stateless/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
log "$method"
./transducer_stateless_multi_datasets/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done

32
.github/scripts/run-pre-trained-transducer.sh vendored Executable file
View File

@ -0,0 +1,32 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
log "Beam search decoding"
./transducer/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -40,11 +40,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -77,104 +72,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs
mkdir -p ~/tmp
cd ~/tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
- name: Display test files
shell: bash
run: |
sudo apt-get -qq install tree sox
tree ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
soxi ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
ls -lh ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
cd egs/librispeech/ASR
./pruned_transducer_stateless/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint $dir/exp/pretrained.pt \
--bpe-model $dir/data/lang_bpe_500/bpe.model \
$dir/test_wavs/1089-134686-0001.wav \
$dir/test_wavs/1221-135766-0001.wav \
$dir/test_wavs/1221-135766-0002.wav
.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh

View File

@ -0,0 +1,85 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-2022-04-29
# stateless pruned transducer (reworked model) + giga speech
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_librispeech_2022_04_29:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh

View File

@ -0,0 +1,82 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-2022-04-19
# stateless transducer + torchaudio rnn-t loss
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_librispeech_2022_04_19:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh

View File

@ -39,11 +39,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -76,48 +71,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
cd ..
tree tmp
soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
- name: Run CTC decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--num-classes 500 \
--checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \
--method ctc-decoding \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
- name: Run HLG decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--num-classes 500 \
--checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
--words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \
--HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
.github/scripts/run-pre-trained-conformer-ctc.sh

View File

@ -39,11 +39,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -76,97 +71,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh

View File

@ -39,11 +39,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -76,99 +71,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless_multi_datasets/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh

View File

@ -39,11 +39,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -76,98 +71,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/aishell/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01
cd ..
tree tmp
soxi tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav
ls -lh tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified-2/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh

View File

@ -39,11 +39,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -76,98 +71,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/aishell/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
cd ..
tree tmp
soxi tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav
ls -lh tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/aishell/ASR
./transducer_stateless_modified/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
--lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh

View File

@ -39,11 +39,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -76,96 +71,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
- name: Run greedy search decoding (max-sym-per-frame 1)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 1 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 2)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 2 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run greedy search decoding (max-sym-per-frame 3)
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--max-sym-per-frame 3 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
- name: Run modified beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method modified_beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
.github/scripts/run-pre-trained-transducer-stateless.sh

View File

@ -39,11 +39,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -76,48 +71,11 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Download pre-trained model
- name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
- name: Run greedy search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer/pretrained.py \
--method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
cd egs/librispeech/ASR
./transducer/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
.github/scripts/run-pre-trained-transducer.sh

View File

@ -35,6 +35,9 @@ We do provide a Colab notebook for this recipe.
### LibriSpeech
Please see <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>
for the **latest** results.
We provide 4 models for this recipe:
- [conformer CTC model][LibriSpeech_conformer_ctc]
@ -92,6 +95,20 @@ in the decoding.
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
#### k2 pruned RNN-T
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.57 | 5.95 |
#### k2 pruned RNN-T + GigaSpeech
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.19 | 4.97 |
### Aishell
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]

1
egs/librispeech/ASR/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
log-*

View File

@ -1,8 +1,8 @@
# Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech/index.html>
for how to run models in this recipe.
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech/index.html> for how to run models in this recipe.
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
@ -12,11 +12,13 @@ The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
| `transducer` | Conformer | LSTM | |
| `transducer_stateless` | Conformer | Embedding + Conv1d | |
| `transducer_stateless` | Conformer | Embedding + Conv1d | Using optimized_transducer from computing RNN-T loss |
| `transducer_stateless2` | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
| `transducer_lstm` | LSTM | LSTM | |
| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
| `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
The decoder in `transducer_stateless` is modified from the paper

View File

@ -1,5 +1,157 @@
## Results
### LibriSpeech BPE training results (Pruned Transducer 3)
[pruned_transducer_stateless3](./pruned_transducer_stateless3)
Same as `Pruned Transducer 2` but using the XL subset from
[GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
During training, it selects either a batch from GigaSpeech with prob `giga_prob`
or a batch from LibriSpeech with prob `1 - giga_prob`. All utterances within
a batch come from the same dataset.
Using commit `ac84220de91dee10c00e8f4223287f937b1930b6`.
See <https://github.com/k2-fsa/icefall/pull/312>.
The WERs are:
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|----------------------------------------|
| greedy search (max sym per frame 1) | 2.21 | 5.09 | --epoch 27 --avg 2 --max-duration 600 |
| greedy search (max sym per frame 1) | 2.25 | 5.02 | --epoch 27 --avg 12 --max-duration 600 |
| modified beam search | 2.19 | 5.03 | --epoch 25 --avg 6 --max-duration 600 |
| modified beam search | 2.23 | 4.94 | --epoch 27 --avg 10 --max-duration 600 |
| beam search | 2.16 | 4.95 | --epoch 25 --avg 7 --max-duration 600 |
| fast beam search | 2.21 | 4.96 | --epoch 27 --avg 10 --max-duration 600 |
| fast beam search | 2.19 | 4.97 | --epoch 27 --avg 12 --max-duration 600 |
The training commands are:
```bash
./prepare.sh
./prepare_giga_speech.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless3/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 0 \
--full-libri 1 \
--exp-dir pruned_transducer_stateless3/exp \
--max-duration 300 \
--use-fp16 1 \
--lr-epochs 4 \
--num-workers 2 \
--giga-prob 0.8
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/gaD34WeYSMCOkzoo3dZXGg/>
(Note: The training process is killed manually after saving `epoch-28.pt`.)
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29>
The decoding commands are:
```bash
# greedy search
./pruned_transducer_stateless3/decode.py \
--epoch 27 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search \
--max-sym-per-frame 1
# modified beam search
./pruned_transducer_stateless3/decode.py \
--epoch 25 \
--avg 6 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--max-sym-per-frame 1
# beam search
./pruned_transducer_stateless3/decode.py \
--epoch 25 \
--avg 7 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--max-sym-per-frame 1
# fast beam search
for epoch in 27; do
for avg in 10 12; do
./pruned_transducer_stateless3/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--max-states 32 \
--beam 8
done
done
```
The following table shows the
[Nbest oracle WER](http://kaldi-asr.org/doc/lattices.html#lattices_operations_oracle)
for fast beam search.
| epoch | avg | num_paths | nbest_scale | test-clean | test-other |
|-------|-----|-----------|-------------|------------|------------|
| 27 | 10 | 50 | 0.5 | 0.91 | 2.74 |
| 27 | 10 | 50 | 0.8 | 0.94 | 2.82 |
| 27 | 10 | 50 | 1.0 | 1.06 | 2.88 |
| 27 | 10 | 100 | 0.5 | 0.82 | 2.58 |
| 27 | 10 | 100 | 0.8 | 0.92 | 2.65 |
| 27 | 10 | 100 | 1.0 | 0.95 | 2.77 |
| 27 | 10 | 200 | 0.5 | 0.81 | 2.50 |
| 27 | 10 | 200 | 0.8 | 0.85 | 2.56 |
| 27 | 10 | 200 | 1.0 | 0.91 | 2.64 |
| 27 | 10 | 400 | 0.5 | N/A | N/A |
| 27 | 10 | 400 | 0.8 | 0.81 | 2.49 |
| 27 | 10 | 400 | 1.0 | 0.85 | 2.54 |
The Nbest oracle WER is computed using the following steps:
- 1. Use `fast_beam_search` to produce a lattice.
- 2. Extract `N` paths from the lattice using [k2.random_path](https://k2-fsa.github.io/k2/python_api/api.html#random-paths)
- 3. [Unique](https://k2-fsa.github.io/k2/python_api/api.html#unique) paths so that each path
has a distinct sequence of tokens
- 4. Compute the edit distance of each path with the ground truth
- 5. The path with the lowest edit distance is the final output and is used to
compute the WER
The command to compute the Nbest oracle WER is:
```bash
for epoch in 27; do
for avg in 10 ; do
for num_paths in 50 100 200 400; do
for nbest_scale in 0.5 0.8 1.0; do
./pruned_transducer_stateless3/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--num-paths $num_paths \
--max-states 32 \
--beam 8 \
--nbest-scale $nbest_scale
done
done
done
done
```
### LibriSpeech BPE training results (Pruned Transducer 2)
[pruned_transducer_stateless2](./pruned_transducer_stateless2)
@ -7,7 +159,7 @@ This is with a reworked version of the conformer encoder, with many changes.
#### Training on fulll librispeech
using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`.
Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`.
See <https://github.com/k2-fsa/icefall/pull/288>
The WERs are:
@ -33,6 +185,10 @@ and:
The Tensorboard log is at <https://tensorboard.dev/experiment/Xoz0oABMTWewo1slNFXkyA> (apologies, log starts
only from epoch 3).
The pretrained models, training logs, decoding logs, and decoding results
can be found at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>
#### Training on train-clean-100:
@ -350,6 +506,78 @@ You can find a pretrained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01>
##### 2022-04-19
[transducer_stateless2](./transducer_stateless2)
This version uses torchaudio's RNN-T loss.
Using commit `fce7f3cd9a486405ee008bcbe4999264f27774a3`.
See <https://github.com/k2-fsa/icefall/pull/316>
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|--------------------------------------------------------------------------------|
| greedy search (max sym per frame 1) | 2.65 | 6.30 | --epoch 59 --avg 10 --max-duration 600 |
| greedy search (max sym per frame 2) | 2.62 | 6.23 | --epoch 59 --avg 10 --max-duration 100 |
| greedy search (max sym per frame 3) | 2.62 | 6.23 | --epoch 59 --avg 10 --max-duration 100 |
| modified beam search | 2.63 | 6.15 | --epoch 59 --avg 10 --max-duration 100 --decoding-method modified_beam_search |
| beam search | 2.59 | 6.15 | --epoch 59 --avg 10 --max-duration 100 --decoding-method beam_search |
**Note**: This model is trained with standard RNN-T loss. Neither modified transducer nor pruned RNN-T is used.
You can see that there is a performance degradation in WER when we limit the max symbol per frame to 1.
The number of active paths in `modified_beam_search` and `beam_search` is 4.
The training and decoding commands are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 60 \
--start-epoch 0 \
--exp-dir transducer_stateless2/exp-2 \
--full-libri 1 \
--max-duration 300 \
--lr-factor 5
epoch=59
avg=10
# greedy search
./transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./transducer_stateless2/exp-2 \
--max-duration 600 \
--decoding-method greedy_search \
--max-sym-per-frame 1
# modified beam search
./transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./transducer_stateless2/exp-2 \
--max-duration 100 \
--decoding-method modified_beam_search \
# beam search
./transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./transducer_stateless2/exp-2 \
--max-duration 100 \
--decoding-method beam_search \
```
The tensorboard log is at <https://tensorboard.dev/experiment/oAlle3dxQD2EY8ePwjIGuw/>.
You can find a pre-trained model, decoding logs, and decoding results at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19>
##### 2022-02-07
Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`.

View File

@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing LG.
"""
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
return LG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "LG.pt").is_file():
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir)
logging.info(f"Saving LG.pt to {lang_dir}")
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import torch
from lhotse import (
CutSet,
KaldifeatFbank,
KaldifeatFbankConfig,
)
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_gigaspeech_dev_test():
in_out_dir = Path("data/fbank")
# number of workers in dataloader
num_workers = 20
# number of seconds in a batch
batch_duration = 600
subsets = ("DEV", "TEST")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
for partition in subsets:
cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
)
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_gigaspeech_dev_test()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,169 @@
#!/usr/bin/env python3
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from datetime import datetime
from pathlib import Path
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-workers",
type=int,
default=20,
help="Number of dataloading workers used for reading the audio.",
)
parser.add_argument(
"--batch-duration",
type=float,
default=600.0,
help="The maximum number of audio seconds in a batch."
"Determines batch size dynamically.",
)
parser.add_argument(
"--num-splits",
type=int,
required=True,
help="The number of splits of the XL subset",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Process pieces starting from this number (inclusive).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop processing pieces until this number (exclusive).",
)
return parser
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
output_dir = f"data/fbank/XL_split_{num_splits}"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
num_digits = len(str(num_splits))
start = args.start
stop = args.stop
if stop < start:
stop = num_splits
stop = min(stop, num_splits)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
num_digits = 8 # num_digits is fixed by lhotse split-lazy
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
if not raw_cuts_path.is_file():
logging.info(f"{raw_cuts_path} does not exist - skipping it")
continue
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info("Computing features")
if (output_dir / f"feats_XL_{idx}.lca").exists():
logging.info(f"Removing {output_dir}/feats_XL_{idx}.lca")
os.remove(output_dir / f"feats_XL_{idx}.lca")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/feats_XL_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
log_filename = "log-compute_fbank_gigaspeech_splits"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
log_filename = f"{log_filename}-{date_time}"
logging.basicConfig(
filename=log_filename,
format=formatter,
level=logging.INFO,
filemode="w",
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
compute_fbank_gigaspeech_splits(args)
if __name__ == "__main__":
main()

View File

@ -91,21 +91,20 @@ def preprocess_giga_speech():
)
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST"]:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
logging.info("About to split cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, min_duration=None
)
# if partition not in ["DEV", "TEST"]:
# logging.info(
# f"Speed perturb for {partition} with factors 0.9 and 1.1 "
# "(Perturbing may take 8 minutes and saving may"
# " take 20 minutes)"
# )
# cut_set = (
# cut_set
# + cut_set.perturb_speed(0.9)
# + cut_set.perturb_speed(1.1)
# )
#
# Note: No need to perturb the training subset as not all of the
# data is going to be used in the training.
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)

View File

@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file can be used to check if any split is corrupted.
"""
import glob
import re
import lhotse
def main():
d = "data/fbank/XL_split_2000"
filenames = list(glob.glob(f"{d}/cuts_XL.*.jsonl.gz"))
pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = [(int(pattern.search(c).group(1)), c) for c in filenames]
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
print(f"Loading {len(idx_filenames)} splits")
s = 0
for i, f in idx_filenames:
cuts = lhotse.load_manifest_lazy(f)
print(i, "filename", f)
for i, c in enumerate(cuts):
s += c.features.load().shape[0]
if i > 5:
break
if __name__ == "__main__":
main()

View File

@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
- Supervision time bounds are within cut time bounds
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/fbank/cuts_train-clean-100.json.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import load_manifest, CutSet
from lhotse.cut import Cut
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def validate_one_supervision_per_cut(c: Cut):
if len(c.supervisions) != 1:
raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
def validate_supervision_and_cut_time_bounds(c: Cut):
s = c.supervisions[0]
if s.start < c.start:
raise ValueError(
f"{c.id}: Supervision start time {s.start} is less "
f"than cut start time {c.start}"
)
if s.end > c.end:
raise ValueError(
f"{c.id}: Supervision end time {s.end} is larger "
f"than cut end time {c.end}"
)
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest(manifest)
assert isinstance(cut_set, CutSet)
for c in cut_set:
validate_one_supervision_per_cut(c)
validate_supervision_and_cut_time_bounds(c)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -118,6 +118,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
./local/compute_fbank_librispeech.py
touch data/fbank/.librispeech.done
fi
if [ ! -e data/fbank/.librispeech-validated.done ]; then
log "Validating data/fbank for LibriSpeech"
parts=(
train-clean-100
train-clean-360
train-other-500
test-clean
test-other
dev-clean
dev-other
)
for part in ${parts[@]}; do
python3 ./local/validate_manifest.py \
data/fbank/cuts_${part}.json.gz
done
touch data/fbank/.librispeech-validated.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
@ -242,3 +260,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
./local/compile_hlg.py --lang-dir $lang_dir
done
fi
# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Compile LG"
./local/compile_lg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir
done
fi

View File

@ -24,6 +24,15 @@ stop_stage=100
# DEV 12 hours
# Test 40 hours
# Split XL subset to this number of pieces
# This is to avoid OOM during feature extraction.
num_splits=2000
# We use lazy split from lhotse.
# The XL subset (10k hours) contains 37956 cuts without speed perturbing.
# We want to split it into 2000 splits, so each split
# contains about 37956 / 2000 = 19 cuts. As a result, there will be 1998 splits.
chunk_size=19 # number of cuts in each split. The last split may contain fewer cuts.
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
@ -107,3 +116,27 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
touch data/fbank/.preprocess_complete
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
python3 ./local/compute_fbank_gigaspeech_dev_test.py
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Split XL subset into ${num_splits} pieces"
split_dir=data/fbank/XL_split_${num_splits}
if [ ! -f $split_dir/.split_completed ]; then
lhotse split-lazy ./data/fbank/cuts_XL_raw.jsonl.gz $split_dir $chunk_size
touch $split_dir/.split_completed
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute features for XL"
# Note: The script supports --start and --stop options.
# You can use several machines to compute the features in parallel.
python3 ./local/compute_fbank_gigaspeech_splits.py \
--num-workers $nj \
--batch-duration 600 \
--num-splits $num_splits
fi

View File

@ -22,7 +22,7 @@ import k2
import torch
from model import Transducer
from icefall.decode import one_best_decoding
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
@ -34,6 +34,7 @@ def fast_beam_search(
beam: float,
max_states: int,
max_contexts: int,
use_max: bool = False,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
@ -53,6 +54,9 @@ def fast_beam_search(
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
use_max:
True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return the decoded result.
"""
@ -104,9 +108,67 @@ def fast_beam_search(
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
if use_max:
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
else:
num_paths = 200
use_double_scores = True
nbest_scale = 0.8
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores, log_semiring=True
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
def greedy_search(
@ -280,7 +342,7 @@ class HypothesisList(object):
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@ -289,13 +351,20 @@ class HypothesisList(object):
Args:
hyp:
The hypothesis to be added.
use_max:
True to select the hypothesis with the larger log_prob in case there
already exists a hypothesis whose `ys` equals to `hyp.ys`.
False to use log_add.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
if use_max:
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
else:
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else:
self._data[key] = hyp
@ -403,6 +472,7 @@ def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -413,6 +483,9 @@ def modified_beam_search(
Output from the encoder. Its shape is (N, T, C).
beam:
Number of active paths during the beam search.
use_max:
True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
@ -432,7 +505,8 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
for t in range(T):
@ -517,6 +591,7 @@ def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
@ -532,6 +607,9 @@ def _deprecated_modified_beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
use_max:
True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return the decoded result.
"""
@ -553,12 +631,13 @@ def _deprecated_modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2)
# current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
# fmt: on
A = list(B)
@ -611,7 +690,7 @@ def _deprecated_modified_beam_search(
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
B.add(new_hyp, use_max=use_max)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
@ -623,6 +702,7 @@ def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -636,6 +716,9 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
use_max:
True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return the decoded result.
"""
@ -661,7 +744,9 @@ def beam_search(
t = 0
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
B.add(
Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max
)
max_sym_per_utt = 20000
@ -720,7 +805,10 @@ def beam_search(
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
B.add(
Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
use_max=use_max,
)
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
@ -729,7 +817,10 @@ def beam_search(
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
A.add(
Hypothesis(ys=new_ys, log_prob=new_log_prob),
use_max=use_max,
)
# Check whether B contains more than "beam" elements more probable
# than the most probable in A

View File

@ -53,6 +53,19 @@ Usage:
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search using LG
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--use-LG True \
--use-max False \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 8 \
--max-contexts 8 \
--max-states 64
"""
@ -81,10 +94,12 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -136,6 +151,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@ -152,7 +174,7 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
@ -167,6 +189,36 @@ def get_parser():
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--use-LG",
type=str2bool,
default=False,
help="""Whether to use an LG graph for FSA-based beam search.
Used only when --decoding_method is fast_beam_search. If setting true,
it assumes there is an LG.pt file in lang_dir.""",
)
parser.add_argument(
"--use-max",
type=str2bool,
default=False,
help="""If True, use max-op to select the hypothesis that have the
max log_prob in case of duplicate hypotheses.
If False, use log_add.
Used only for beam_search, modified_beam_search, and fast_beam_search
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
@ -206,6 +258,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -229,6 +282,8 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
@ -260,9 +315,14 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
use_max=params.use_max,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
if params.use_LG:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
else:
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -278,6 +338,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
use_max=params.use_max,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -299,6 +360,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
use_max=params.use_max,
)
else:
raise ValueError(
@ -325,6 +387,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -338,6 +401,8 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
@ -368,8 +433,9 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
word_table=word_table,
decoding_graph=decoding_graph,
)
for name, hyps in hyps_dict.items():
@ -460,13 +526,16 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-use-LG-{params.use_LG}"
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-use-max-{params.use_max}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-use-max-{params.use_max}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -527,9 +596,21 @@ def main():
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
if params.use_LG:
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
decoding_graph = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -551,6 +632,7 @@ def main():
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)

View File

@ -811,13 +811,23 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
logging.info(
f"Before removing short and long utterances: {num_in_total}"
)
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
)
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
@ -21,11 +22,11 @@ import k2
import torch
from model import Transducer
from icefall.decode import one_best_decoding
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search(
def fast_beam_search_one_best(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
@ -36,6 +37,9 @@ def fast_beam_search(
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
@ -55,6 +59,148 @@ def fast_beam_search(
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.
This is the best result we can achieve for any nbest based rescoring
methods.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
ref_texts:
A list-of-list of integers containing the reference transcripts.
If the decoding_graph is a trivial_graph, the integer ID is the
BPE token ID.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
lattice is actually an acceptor.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
@ -103,9 +249,7 @@ def fast_beam_search(
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
return lattice
def greedy_search(
@ -130,6 +274,7 @@ def greedy_search(
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = model.device
@ -170,7 +315,7 @@ def greedy_search(
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y != blank_id:
if y not in (blank_id, unk_id):
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
@ -211,6 +356,7 @@ def greedy_search_batch(
T = encoder_out.size(1)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
@ -239,7 +385,7 @@ def greedy_search_batch(
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
if v not in (blank_id, unk_id):
hyps[i].append(v)
emitted = True
if emitted:
@ -432,6 +578,7 @@ def modified_beam_search(
T = encoder_out.size(1)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
@ -503,8 +650,10 @@ def modified_beam_search(
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
@ -512,7 +661,7 @@ def modified_beam_search(
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
@ -553,6 +702,7 @@ def _deprecated_modified_beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
@ -614,14 +764,16 @@ def _deprecated_modified_beam_search(
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id:
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
@ -658,6 +810,7 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
@ -743,7 +896,7 @@ def beam_search(
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i == blank_id:
if i in (blank_id, unk_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v

View File

@ -69,7 +69,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -98,27 +98,28 @@ def get_parser():
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
"'--epoch' and '--iter'",
)
parser.add_argument(
@ -151,7 +152,7 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
@ -251,7 +252,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -453,13 +454,19 @@ def main():
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -476,8 +483,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -485,8 +493,20 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -0,0 +1,306 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
(1) beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
./pruned_transducer_stateless2/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=8.0,
max_contexts=32,
max_states=8,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use_fp16 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 550

View File

@ -0,0 +1,314 @@
# Copyright 2021 Piotr Żelasko
# 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import Optional
from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import (
BucketingSampler,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
SpecAugment,
)
from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
PrecomputedFeatures,
)
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class AsrDataModule:
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler "
"and DynamicBucketingSampler."
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--on-the-fly-num-workers",
type=int,
default=0,
help="The number of workers for on-the-fly feature extraction",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available. Used only in dev/test CutSet",
)
def train_dataloaders(
self,
cuts_train: CutSet,
dynamic_bucketing: bool,
on_the_fly_feats: bool,
cuts_musan: Optional[CutSet] = None,
) -> DataLoader:
"""
Args:
cuts_train:
Cuts for training.
cuts_musan:
If not None, it is the cuts for mixing.
dynamic_bucketing:
True to use DynamicBucketingSampler;
False to use BucketingSampler.
on_the_fly_feats:
True to use OnTheFlyFeatures;
False to use PrecomputedFeatures.
"""
transforms = []
if cuts_musan is not None:
logging.info("Enable MUSAN")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=(
OnTheFlyFeatures(
extractor=Fbank(FbankConfig(num_mel_bins=80)),
num_workers=self.args.on_the_fly_num_workers,
)
if on_the_fly_feats
else PrecomputedFeatures()
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if dynamic_bucketing:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=True,
)
else:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = BucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/conformer.py

View File

@ -0,0 +1,567 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless3/decode-giga.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless3/decode-giga.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless3/decode-giga.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless3/decode-giga.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from gigaspeech import GigaSpeech
from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
def post_processing(
results: List[Tuple[List[List[str]], List[List[str]]]],
) -> List[Tuple[List[List[str]], List[List[str]]]]:
new_results = []
for ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((new_ref, new_hyp))
return new_results
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = post_processing(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / "giga" / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
# In beam_search.py, we are using model.decoder() and model.joiner(),
# so we have to switch to the branch for the GigaSpeech dataset.
model.decoder = model.decoder_giga
model.joiner = model.joiner_giga
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
asr_datamodule = AsrDataModule(args)
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
test_cuts = gigaspeech.test_cuts()
dev_cuts = gigaspeech.dev_cuts()
test_dl = asr_datamodule.test_dataloaders(test_cuts)
dev_dl = asr_datamodule.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_sets_dl = [test_dl, dev_dl]
for test_set, dl in zip(test_sets, test_sets_dl):
results_dict = decode_dataset(
dl=dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,626 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from librispeech import LibriSpeech
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for computed nbest oracle WER
when the decoding method is fast_beam_search_nbest_oracle.
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding_method is fast_beam_search_nbest_oracle.
""",
)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is
fast_beam_search or fast_beam_search_nbest_oracle.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
elif params.decoding_method == "fast_beam_search_nbest_oracle":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif params.decoding_method == "fast_beam_search_nbest_oracle":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.unk_id()
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
model.unk_id = params.unk_id
if params.decoding_method in (
"fast_beam_search",
"fast_beam_search_nbest_oracle",
):
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
asr_datamodule = AsrDataModule(args)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,183 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless3/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless3/decode.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,94 @@
# Copyright 2021 Piotr Żelasko
# 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import re
from pathlib import Path
import lhotse
from lhotse import CutSet, load_manifest
class GigaSpeech:
def __init__(self, manifest_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files::
- XL_split_2000/cuts_XL.*.jsonl.gz
- cuts_L_raw.jsonl.gz
- cuts_M_raw.jsonl.gz
- cuts_S_raw.jsonl.gz
- cuts_XS_raw.jsonl.gz
- cuts_DEV_raw.jsonl.gz
- cuts_TEST_raw.jsonl.gz
"""
self.manifest_dir = Path(manifest_dir)
def train_XL_cuts(self) -> CutSet:
logging.info("About to get train-XL cuts")
filenames = list(
glob.glob(f"{self.manifest_dir}/XL_split_2000/cuts_XL.*.jsonl.gz")
)
pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz")
idx_filenames = [
(int(pattern.search(f).group(1)), f) for f in filenames
]
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(f"Loading {len(sorted_filenames)} splits")
return lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
def train_L_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_L_raw.jsonl.gz"
logging.info(f"About to get train-L cuts from {f}")
return CutSet.from_jsonl_lazy(f)
def train_M_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_M_raw.jsonl.gz"
logging.info(f"About to get train-M cuts from {f}")
return CutSet.from_jsonl_lazy(f)
def train_S_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_S_raw.jsonl.gz"
logging.info(f"About to get train-S cuts from {f}")
return CutSet.from_jsonl_lazy(f)
def train_XS_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_XS_raw.jsonl.gz"
logging.info(f"About to get train-XS cuts from {f}")
return CutSet.from_jsonl_lazy(f)
def test_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_TEST.jsonl.gz"
logging.info(f"About to get TEST cuts from {f}")
return load_manifest(f)
def dev_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_DEV.jsonl.gz"
logging.info(f"About to get DEV cuts from {f}")
return load_manifest(f)

View File

@ -0,0 +1 @@
../../../gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1,74 @@
# Copyright 2021 Piotr Żelasko
# 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest
class LibriSpeech:
def __init__(self, manifest_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files::
- cuts_dev-clean.json.gz
- cuts_dev-other.json.gz
- cuts_test-clean.json.gz
- cuts_test-other.json.gz
- cuts_train-clean-100.json.gz
- cuts_train-clean-360.json.gz
- cuts_train-other-500.json.gz
"""
self.manifest_dir = Path(manifest_dir)
def train_clean_100_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_train-clean-100.json.gz"
logging.info(f"About to get train-clean-100 cuts from {f}")
return load_manifest(f)
def train_clean_360_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_train-clean-360.json.gz"
logging.info(f"About to get train-clean-360 cuts from {f}")
return load_manifest(f)
def train_other_500_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_train-other-500.json.gz"
logging.info(f"About to get train-other-500 cuts from {f}")
return load_manifest(f)
def test_clean_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_test-clean.json.gz"
logging.info(f"About to get test-clean cuts from {f}")
return load_manifest(f)
def test_other_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_test-other.json.gz"
logging.info(f"About to get test-other cuts from {f}")
return load_manifest(f)
def dev_clean_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_dev-clean.json.gz"
logging.info(f"About to get dev-clean cuts from {f}")
return load_manifest(f)
def dev_other_cuts(self) -> CutSet:
f = self.manifest_dir / "cuts_dev-other.json.gz"
logging.info(f"About to get dev-other cuts from {f}")
return load_manifest(f)

View File

@ -0,0 +1,235 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
decoder_giga: Optional[nn.Module] = None,
joiner_giga: Optional[nn.Module] = None,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
encoder_dim:
Output dimension of the encoder network.
decoder_dim:
Output dimension of the decoder network.
joiner_dim:
Input dimension of the joiner network.
vocab_size:
Output dimension of the joiner network.
decoder_giga:
Optional. The decoder network for the GigaSpeech dataset.
joiner_giga:
Optional. The joiner network for the GigaSpeech dataset.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.decoder_giga = decoder_giga
self.joiner_giga = joiner_giga
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_giga is not None:
self.simple_am_proj_giga = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj_giga = ScaledLinear(decoder_dim, vocab_size)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
libri: bool = True,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
libri:
True to use the decoder and joiner for the LibriSpeech dataset.
False to use the decoder and joiner for the GigaSpeech dataset.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
assert torch.all(encoder_out_lens > 0)
if libri:
decoder = self.decoder
simple_lm_proj = self.simple_lm_proj
simple_am_proj = self.simple_am_proj
joiner = self.joiner
else:
decoder = self.decoder_giga
simple_lm_proj = self.simple_lm_proj_giga
simple_am_proj = self.simple_am_proj_giga
joiner = self.joiner_giga
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=joiner.encoder_proj(encoder_out),
lm=joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,306 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
(1) beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./pruned_transducer_stateless3/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless3/exp/pretrained.pt is generated by
./pruned_transducer_stateless3/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=8.0,
max_contexts=32,
max_states=8,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -89,9 +89,9 @@ class Decoder(nn.Module):
- (h, c), containing the state information for LSTM layers.
Both are of shape (num_layers, N, C)
"""
embeding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states)
embedding_out = self.embedding(y)
embedding_out = self.embedding_dropout(embedding_out)
rnn_out, (h, c) = self.rnn(embedding_out, states)
out = self.output_linear(rnn_out)
return out, (h, c)

View File

@ -93,9 +93,9 @@ class Decoder(nn.Module):
- (h, c), containing the state information for LSTM layers.
Both are of shape (num_layers, N, C)
"""
embeding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states)
embedding_out = self.embedding(y)
embedding_out = self.embedding_dropout(embedding_out)
rnn_out, (h, c) = self.rnn(embedding_out, states)
out = self.output_linear(rnn_out)
return out, (h, c)

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
@ -505,8 +506,10 @@ def modified_beam_search(
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
@ -613,8 +616,10 @@ def _deprecated_modified_beam_search(
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]

View File

@ -653,13 +653,23 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
logging.info(
f"Before removing short and long utterances: {num_in_total}"
)
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
)
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts)

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@ -0,0 +1 @@
../transducer_stateless/beam_search.py

View File

@ -0,0 +1 @@
../transducer_stateless/conformer.py

View File

@ -0,0 +1,443 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless2/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_stateless2/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_stateless2/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=13,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyp_list: List[List[int]] = []
if (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
else:
return {f"beam_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../transducer_stateless/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,181 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_stateless2/export.py \
--exp-dir ./transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_stateless2/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer_stateless2/decode.py \
--exp-dir ./transducer_stateless2/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=20,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,67 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
*unused,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out:
Output from the decoder. Its shape is (N, U, self.input_dim).
unused:
This is a placeholder so that we can reuse
transducer_stateless/beam_search.py in this folder as that
script assumes the joiner networks accepts 4 inputs.
Returns:
Return a tensor of shape (N, T, U, self.output_dim).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == self.input_dim
assert decoder_out.size(2) == self.input_dim
encoder_out = encoder_out.unsqueeze(2) # (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1) # (N, 1, U, C)
x = encoder_out + decoder_out # (N, T, U, C)
activations = torch.tanh(x)
logits = self.output_linear(activations)
if not self.training:
# We reuse the beam_search.py from transducer_stateless,
# which expects that the joiner network outputs
# a 2-D tensor.
logits = logits.squeeze(2).squeeze(1)
return logits

View File

@ -0,0 +1,130 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out = self.decoder(sos_y_padded)
logits = self.joiner(
encoder_out=encoder_out,
decoder_out=decoder_out,
)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
return loss

View File

@ -0,0 +1,293 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav \
(2) beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
(3) modified beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./transducer_stateless2/exp/epoch-xx.pt`.
Note: ./transducer_stateless2/exp/pretrained.pt is generated by
./transducer_stateless2/export.py
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyp_list = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../transducer_stateless/subsampling.py

View File

@ -0,0 +1,779 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless2/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
"""
import argparse
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless2/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lr-factor",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for Noam
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(
x=feature,
x_lens=feature_lens,
y=y,
)
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if params.print_diagnostics and batch_idx == 5:
return
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(
f"Before removing short and long utterances: {num_in_total}"
)
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
)
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../transducer_stateless/transformer.py

View File

@ -84,9 +84,9 @@ class Decoder(nn.Module):
- (h, c), which contain the state information for RNN layers.
Both are of shape (num_layers, N, C)
"""
embeding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states)
embedding_out = self.embedding(y)
embedding_out = self.embedding_dropout(embedding_out)
rnn_out, (h, c) = self.rnn(embedding_out, states)
out = self.output_linear(rnn_out)
return out, (h, c)

View File

@ -150,12 +150,25 @@ def average_checkpoints(
n = len(filenames)
avg = torch.load(filenames[0], map_location=device)["model"]
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n):
state_dict = torch.load(filenames[i], map_location=device)["model"]
for k in avg:
for k in uniqued_names:
avg[k] += state_dict[k]
for k in avg:
for k in uniqued_names:
if avg[k].is_floating_point():
avg[k] /= n
else:

View File

@ -11,7 +11,7 @@ graphviz==0.19.1
-f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu
-f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu
-f https://k2-fsa.org/nightly/ k2==1.14.dev20220316+cpu.torch1.10.0
-f https://k2-fsa.org/nightly/ k2==1.15.1.dev20220426+cpu.torch1.10.0
git+https://github.com/lhotse-speech/lhotse
kaldilm==1.11