mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-14 19:44:21 +00:00
Merge branch 'k2-fsa:master' into tiny
This commit is contained in:
commit
49c1159684
1
.flake8
1
.flake8
@ -13,6 +13,7 @@ per-file-ignores =
|
||||
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
|
||||
egs/librispeech/ASR/conformer_ctc*/*py: E501,
|
||||
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
|
||||
egs/librispeech/ASR/zipformer/*.py: E501, E203
|
||||
egs/librispeech/ASR/RESULTS.md: E999,
|
||||
|
||||
# invalid escape sequence (cause by tex formular), W605
|
||||
|
@ -21,9 +21,9 @@ tree $repo/
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
git lfs pull --include "data/lang_bpe_500/HLG.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/L.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/LG.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/HLG.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/Linv.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/cpu_jit.pt"
|
116
.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
vendored
Executable file
116
.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
vendored
Executable file
@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
||||
git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
ls -lh *.pt
|
||||
popd
|
||||
|
||||
log "Export to torchscript model"
|
||||
./zipformer/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--use-averaged-model false \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
./zipformer/jit_pretrained_streaming.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav
|
||||
|
||||
for method in greedy_search modified_beam_search fast_beam_search; do
|
||||
log "$method"
|
||||
|
||||
./zipformer/pretrained.py \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--method $method \
|
||||
--beam-size 4 \
|
||||
--checkpoint $repo/exp/pretrained.pt \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
done
|
||||
|
||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||
mkdir -p zipformer/exp
|
||||
ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
|
||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||
|
||||
ls -lh data
|
||||
ls -lh zipformer/exp
|
||||
|
||||
log "Decoding test-clean and test-other"
|
||||
|
||||
# use a small value for decoding with CPU
|
||||
max_duration=100
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Simulated streaming decoding with $method"
|
||||
|
||||
./zipformer/decode.py \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method $method \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--max-duration $max_duration \
|
||||
--exp-dir zipformer/exp
|
||||
done
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Chunk-wise streaming decoding with $method"
|
||||
|
||||
./zipformer/streaming_decode.py \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method $method \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--max-duration $max_duration \
|
||||
--exp-dir zipformer/exp
|
||||
done
|
||||
|
||||
rm zipformer/exp/*.pt
|
||||
fi
|
94
.github/scripts/run-librispeech-zipformer-2023-05-18.sh
vendored
Executable file
94
.github/scripts/run-librispeech-zipformer-2023-05-18.sh
vendored
Executable file
@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
||||
git lfs pull --include "exp/jit_script.pt"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
ls -lh *.pt
|
||||
popd
|
||||
|
||||
log "Export to torchscript model"
|
||||
./zipformer/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--use-averaged-model false \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
./zipformer/jit_pretrained.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--nn-model-filename $repo/exp/jit_script.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
for method in greedy_search modified_beam_search fast_beam_search; do
|
||||
log "$method"
|
||||
|
||||
./zipformer/pretrained.py \
|
||||
--method $method \
|
||||
--beam-size 4 \
|
||||
--checkpoint $repo/exp/pretrained.pt \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
done
|
||||
|
||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||
mkdir -p zipformer/exp
|
||||
ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
|
||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||
|
||||
ls -lh data
|
||||
ls -lh zipformer/exp
|
||||
|
||||
log "Decoding test-clean and test-other"
|
||||
|
||||
# use a small value for decoding with CPU
|
||||
max_duration=100
|
||||
|
||||
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||
log "Decoding with $method"
|
||||
|
||||
./zipformer/decode.py \
|
||||
--decoding-method $method \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--max-duration $max_duration \
|
||||
--exp-dir zipformer/exp
|
||||
done
|
||||
|
||||
rm zipformer/exp/*.pt
|
||||
fi
|
117
.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
vendored
Executable file
117
.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
vendored
Executable file
@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
|
||||
|
||||
log "Downloading pre-trained model from $repo_url"
|
||||
git lfs install
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
log "Display test files"
|
||||
tree $repo/
|
||||
ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
||||
git lfs pull --include "data/lang_bpe_500/HLG.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/L.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/LG.pt"
|
||||
git lfs pull --include "data/lang_bpe_500/Linv.pt"
|
||||
git lfs pull --include "data/lm/G_4_gram.pt"
|
||||
git lfs pull --include "exp/jit_script.pt"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
ls -lh *.pt
|
||||
popd
|
||||
|
||||
log "Export to torchscript model"
|
||||
./zipformer/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--use-transducer 1 \
|
||||
--use-ctc 1 \
|
||||
--use-averaged-model false \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
for method in ctc-decoding 1best; do
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--model-filename $repo/exp/jit_script.pt \
|
||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||
--G $repo/data/lm/G_4_gram.pt \
|
||||
--method $method \
|
||||
--sample-rate 16000 \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
done
|
||||
|
||||
for method in ctc-decoding 1best; do
|
||||
log "$method"
|
||||
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--use-transducer 1 \
|
||||
--use-ctc 1 \
|
||||
--method $method \
|
||||
--checkpoint $repo/exp/pretrained.pt \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||
--G $repo/data/lm/G_4_gram.pt \
|
||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||
--sample-rate 16000 \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
done
|
||||
|
||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||
mkdir -p zipformer/exp
|
||||
ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
|
||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||
|
||||
ls -lh data
|
||||
ls -lh zipformer/exp
|
||||
|
||||
log "Decoding test-clean and test-other"
|
||||
|
||||
# use a small value for decoding with CPU
|
||||
max_duration=100
|
||||
|
||||
for method in ctc-decoding 1best; do
|
||||
log "Decoding with $method"
|
||||
|
||||
./zipformer/ctc_decode.py \
|
||||
--use-transducer 1 \
|
||||
--use-ctc 1 \
|
||||
--decoding-method $method \
|
||||
--nbest-scale 1.0 \
|
||||
--hlg-scale 0.6 \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--use-averaged-model 0 \
|
||||
--max-duration $max_duration \
|
||||
--exp-dir zipformer/exp
|
||||
done
|
||||
|
||||
rm zipformer/exp/*.pt
|
||||
fi
|
4
.github/scripts/test-ncnn-export.sh
vendored
4
.github/scripts/test-ncnn-export.sh
vendored
@ -195,14 +195,14 @@ git lfs pull --include "data/lang_char_bpe/Linv.pt"
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
popd
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||
--lang-dir $repo/data/lang_char_bpe \
|
||||
--exp-dir $repo/exp \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--decode-chunk-len 32 \
|
||||
--num-encoder-layers "2,4,3,2,4" \
|
||||
|
6
.github/workflows/run-aishell-2022-06-20.yml
vendored
6
.github/workflows/run-aishell-2022-06-20.yml
vendored
@ -44,7 +44,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -73,7 +73,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -119,5 +119,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20
|
||||
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-06-20
|
||||
path: egs/aishell/ASR/pruned_transducer_stateless3/exp/
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -122,5 +122,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
|
||||
path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -155,5 +155,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless-2022-03-12
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless/exp/
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -174,12 +174,12 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-04-29
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
|
||||
|
||||
- name: Upload decoding results for pruned_transducer_stateless3
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -155,5 +155,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless5-2022-05-13
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless5-2022-05-13
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/
|
||||
|
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -155,5 +155,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-2022-11-11
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-2022-11-11
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/
|
||||
|
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -155,5 +155,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/
|
||||
|
@ -68,7 +68,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -159,5 +159,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-2022-12-01
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/
|
||||
|
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -163,5 +163,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer_mmi-2022-12-08
|
||||
path: egs/librispeech/ASR/zipformer_mmi/exp/
|
||||
|
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -168,5 +168,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-streaming-2022-12-29
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/
|
||||
|
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: run-librispeech-2022-12-15-stateless7-ctc-bs
|
||||
name: run-librispeech-2023-01-29-stateless7-ctc-bs
|
||||
# zipformer
|
||||
|
||||
on:
|
||||
@ -34,7 +34,7 @@ on:
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
jobs:
|
||||
run_librispeech_2022_12_15_zipformer_ctc_bs:
|
||||
run_librispeech_2023_01_29_zipformer_ctc_bs:
|
||||
if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
@ -68,7 +68,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -124,7 +124,7 @@ jobs:
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
|
||||
.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh
|
||||
|
||||
- name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
@ -159,5 +159,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/
|
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -151,5 +151,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-conformer_ctc3-2022-11-28
|
||||
path: egs/librispeech/ASR/conformer_ctc3/exp/
|
||||
|
@ -26,7 +26,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
@ -55,7 +55,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -159,5 +159,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03
|
||||
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -153,5 +153,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
|
||||
|
@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -155,5 +155,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-06-26
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-06-26
|
||||
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
|
||||
|
174
.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml
vendored
Normal file
174
.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml
vendored
Normal file
@ -0,0 +1,174 @@
|
||||
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: run-librispeech-streaming-zipformer-2023-05-18
|
||||
# zipformer
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
schedule:
|
||||
# minute (0-59)
|
||||
# hour (0-23)
|
||||
# day of the month (1-31)
|
||||
# month (1-12)
|
||||
# day of the week (0-6)
|
||||
# nightly build at 15:50 UTC time every day
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
concurrency:
|
||||
group: run_librispeech_2023_05_18_streaming_zipformer-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_librispeech_2023_05_18_streaming_zipformer:
|
||||
if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf==3.20.*
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||
id: libri-test-clean-and-test-other-data
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/download
|
||||
key: cache-libri-test-clean-and-test-other
|
||||
|
||||
- name: Download LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||
|
||||
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||
id: libri-test-clean-and-test-other-fbank
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/fbank-libri
|
||||
key: cache-libri-fbank-test-clean-and-test-other-v2
|
||||
|
||||
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||
|
||||
- name: Inference with pre-trained model
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
mkdir -p egs/librispeech/ASR/data
|
||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||
ls -lh egs/librispeech/ASR/data/*
|
||||
|
||||
sudo apt-get -qq install git-lfs tree
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
|
||||
|
||||
- name: Display decoding results for librispeech zipformer
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
shell: bash
|
||||
run: |
|
||||
cd egs/librispeech/ASR/
|
||||
tree ./zipformer/exp
|
||||
|
||||
cd zipformer
|
||||
|
||||
echo "results for zipformer, simulated streaming decoding"
|
||||
echo "===greedy search==="
|
||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===fast_beam_search==="
|
||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===modified beam search==="
|
||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "results for zipformer, chunk-wise streaming decoding"
|
||||
echo "===greedy search==="
|
||||
find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===fast_beam_search==="
|
||||
find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===modified beam search==="
|
||||
find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
|
||||
- name: Upload decoding results for librispeech zipformer
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
|
||||
path: egs/librispeech/ASR/zipformer/exp/
|
@ -43,7 +43,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -155,5 +155,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless2-2022-04-19
|
||||
path: egs/librispeech/ASR/transducer_stateless2/exp/
|
||||
|
159
.github/workflows/run-librispeech-zipformer-2023-05-18.yml
vendored
Normal file
159
.github/workflows/run-librispeech-zipformer-2023-05-18.yml
vendored
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: run-librispeech-zipformer-2023-05-18
|
||||
# zipformer
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
schedule:
|
||||
# minute (0-59)
|
||||
# hour (0-23)
|
||||
# day of the month (1-31)
|
||||
# month (1-12)
|
||||
# day of the week (0-6)
|
||||
# nightly build at 15:50 UTC time every day
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
concurrency:
|
||||
group: run_librispeech_2023_05_18_zipformer-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_librispeech_2023_05_18_zipformer:
|
||||
if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf==3.20.*
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||
id: libri-test-clean-and-test-other-data
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/download
|
||||
key: cache-libri-test-clean-and-test-other
|
||||
|
||||
- name: Download LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||
|
||||
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||
id: libri-test-clean-and-test-other-fbank
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/fbank-libri
|
||||
key: cache-libri-fbank-test-clean-and-test-other-v2
|
||||
|
||||
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||
|
||||
- name: Inference with pre-trained model
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
mkdir -p egs/librispeech/ASR/data
|
||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||
ls -lh egs/librispeech/ASR/data/*
|
||||
|
||||
sudo apt-get -qq install git-lfs tree
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-librispeech-zipformer-2023-05-18.sh
|
||||
|
||||
- name: Display decoding results for librispeech zipformer
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
shell: bash
|
||||
run: |
|
||||
cd egs/librispeech/ASR/
|
||||
tree ./zipformer/exp
|
||||
|
||||
cd zipformer
|
||||
echo "results for zipformer"
|
||||
echo "===greedy search==="
|
||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===fast_beam_search==="
|
||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===modified beam search==="
|
||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
- name: Upload decoding results for librispeech zipformer
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
|
||||
path: egs/librispeech/ASR/zipformer/exp/
|
155
.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml
vendored
Normal file
155
.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml
vendored
Normal file
@ -0,0 +1,155 @@
|
||||
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
|
||||
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: run-librispeech-zipformer-ctc-2023-06-14
|
||||
# zipformer
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
schedule:
|
||||
# minute (0-59)
|
||||
# hour (0-23)
|
||||
# day of the month (1-31)
|
||||
# month (1-12)
|
||||
# day of the week (0-6)
|
||||
# nightly build at 15:50 UTC time every day
|
||||
- cron: "50 15 * * *"
|
||||
|
||||
concurrency:
|
||||
group: run_librispeech_2023_06_14_zipformer-ctc-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
run_librispeech_2023_06_14_zipformer_ctc:
|
||||
if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: '**/requirements-ci.txt'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf==3.20.*
|
||||
|
||||
- name: Cache kaldifeat
|
||||
id: my-cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/install-kaldifeat.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||
id: libri-test-clean-and-test-other-data
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/download
|
||||
key: cache-libri-test-clean-and-test-other
|
||||
|
||||
- name: Download LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||
|
||||
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||
|
||||
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||
id: libri-test-clean-and-test-other-fbank
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/tmp/fbank-libri
|
||||
key: cache-libri-fbank-test-clean-and-test-other-v2
|
||||
|
||||
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||
|
||||
- name: Inference with pre-trained model
|
||||
shell: bash
|
||||
env:
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||
run: |
|
||||
mkdir -p egs/librispeech/ASR/data
|
||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||
ls -lh egs/librispeech/ASR/data/*
|
||||
|
||||
sudo apt-get -qq install git-lfs tree
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||
|
||||
.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
|
||||
|
||||
- name: Display decoding results for librispeech zipformer
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
shell: bash
|
||||
run: |
|
||||
cd egs/librispeech/ASR/
|
||||
tree ./zipformer/exp
|
||||
|
||||
cd zipformer
|
||||
echo "results for zipformer"
|
||||
echo "===ctc-decoding==="
|
||||
find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
echo "===1best==="
|
||||
find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||
find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||
|
||||
- name: Upload decoding results for librispeech zipformer
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
|
||||
path: egs/librispeech/ASR/zipformer/exp/
|
@ -33,7 +33,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
|
@ -42,7 +42,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -71,7 +71,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -154,5 +154,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
|
||||
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
|
||||
|
@ -42,7 +42,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -71,7 +71,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -154,5 +154,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
|
||||
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
|
||||
|
@ -33,7 +33,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
|
@ -33,7 +33,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
|
@ -42,7 +42,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -71,7 +71,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
@ -154,5 +154,5 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||
with:
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
|
||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless-2022-02-07
|
||||
path: egs/librispeech/ASR/transducer_stateless/exp/
|
||||
|
@ -33,7 +33,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
|
||||
fail-fast: false
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
|
@ -33,7 +33,7 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
|
||||
fail-fast: false
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
|
6
.github/workflows/run-yesno-recipe.yml
vendored
6
.github/workflows/run-yesno-recipe.yml
vendored
@ -33,9 +33,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
# os: [ubuntu-18.04, macos-10.15]
|
||||
# os: [ubuntu-latest, macos-10.15]
|
||||
# TODO: enable macOS for CPU testing
|
||||
os: [ubuntu-18.04]
|
||||
os: [ubuntu-latest]
|
||||
python-version: [3.8]
|
||||
fail-fast: false
|
||||
|
||||
@ -69,6 +69,8 @@ jobs:
|
||||
pip uninstall -y protobuf
|
||||
pip install --no-binary protobuf protobuf==3.20.*
|
||||
|
||||
pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl
|
||||
|
||||
- name: Run yesno recipe
|
||||
shell: bash
|
||||
working-directory: ${{github.workspace}}
|
||||
|
2
.github/workflows/test-ncnn-export.yml
vendored
2
.github/workflows/test-ncnn-export.yml
vendored
@ -54,7 +54,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
|
2
.github/workflows/test-onnx-export.yml
vendored
2
.github/workflows/test-onnx-export.yml
vendored
@ -54,7 +54,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
~/tmp/kaldifeat
|
||||
key: cache-tmp-${{ matrix.python-version }}-2022-09-25
|
||||
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||
|
||||
- name: Install kaldifeat
|
||||
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||
|
1
.github/workflows/test.yml
vendored
1
.github/workflows/test.yml
vendored
@ -113,6 +113,7 @@ jobs:
|
||||
cd ../pruned_transducer_stateless4
|
||||
pytest -v -s
|
||||
|
||||
echo $PYTHONPATH
|
||||
cd ../pruned_transducer_stateless7
|
||||
pytest -v -s
|
||||
|
||||
|
@ -26,7 +26,7 @@ repos:
|
||||
# E121,E123,E126,E226,E24,E704,W503,W504
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.10.1
|
||||
rev: 5.11.5
|
||||
hooks:
|
||||
- id: isort
|
||||
args: ["--profile=black"]
|
||||
|
159
README.md
159
README.md
@ -28,14 +28,15 @@ We provide the following recipes:
|
||||
|
||||
- [yesno][yesno]
|
||||
- [LibriSpeech][librispeech]
|
||||
- [GigaSpeech][gigaspeech]
|
||||
- [Aishell][aishell]
|
||||
- [Aishell2][aishell2]
|
||||
- [Aishell4][aishell4]
|
||||
- [TIMIT][timit]
|
||||
- [TED-LIUM3][tedlium3]
|
||||
- [GigaSpeech][gigaspeech]
|
||||
- [Aidatatang_200zh][aidatatang_200zh]
|
||||
- [WenetSpeech][wenetspeech]
|
||||
- [Alimeeting][alimeeting]
|
||||
- [Aishell4][aishell4]
|
||||
- [TAL_CSASR][tal_csasr]
|
||||
|
||||
### yesno
|
||||
@ -46,9 +47,7 @@ Training takes less than 30 seconds and gives you the following WER:
|
||||
```
|
||||
[test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
```
|
||||
We do provide a Colab notebook for this recipe.
|
||||
|
||||
[](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
|
||||
We provide a Colab notebook for this recipe: [](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
|
||||
|
||||
|
||||
### LibriSpeech
|
||||
@ -56,12 +55,13 @@ We do provide a Colab notebook for this recipe.
|
||||
Please see <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>
|
||||
for the **latest** results.
|
||||
|
||||
We provide 4 models for this recipe:
|
||||
We provide 5 models for this recipe:
|
||||
|
||||
- [conformer CTC model][LibriSpeech_conformer_ctc]
|
||||
- [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc]
|
||||
- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer]
|
||||
- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless]
|
||||
- [Transducer: Zipformer encoder + Embedding decoder][LibriSpeech_zipformer]
|
||||
|
||||
#### Conformer CTC Model
|
||||
|
||||
@ -116,21 +116,58 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
|
||||
|
||||
#### k2 pruned RNN-T
|
||||
|
||||
| | test-clean | test-other |
|
||||
|-----|------------|------------|
|
||||
| WER | 2.57 | 5.95 |
|
||||
| Encoder | Params | test-clean | test-other |
|
||||
|-----------------|--------|------------|------------|
|
||||
| zipformer | 65.5M | 2.21 | 4.91 |
|
||||
| zipformer-small | 23.2M | 2.46 | 5.83 |
|
||||
| zipformer-large | 148.4M | 2.11 | 4.77 |
|
||||
|
||||
Note: No auxiliary losses are used in the training and no LMs are used
|
||||
in the decoding.
|
||||
|
||||
#### k2 pruned RNN-T + GigaSpeech
|
||||
|
||||
| | test-clean | test-other |
|
||||
|-----|------------|------------|
|
||||
| WER | 2.00 | 4.63 |
|
||||
| WER | 1.78 | 4.08 |
|
||||
|
||||
Note: No auxiliary losses are used in the training and no LMs are used
|
||||
in the decoding.
|
||||
|
||||
#### k2 pruned RNN-T + GigaSpeech + CommonVoice
|
||||
|
||||
| | test-clean | test-other |
|
||||
|-----|------------|------------|
|
||||
| WER | 1.90 | 3.98 |
|
||||
|
||||
Note: No auxiliary losses are used in the training and no LMs are used
|
||||
in the decoding.
|
||||
|
||||
|
||||
### GigaSpeech
|
||||
|
||||
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
|
||||
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
|
||||
|
||||
#### Conformer CTC
|
||||
|
||||
| | Dev | Test |
|
||||
|-----|-------|-------|
|
||||
| WER | 10.47 | 10.58 |
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||
|
||||
| | Dev | Test |
|
||||
|----------------------|-------|-------|
|
||||
| greedy search | 10.51 | 10.73 |
|
||||
| fast beam search | 10.50 | 10.69 |
|
||||
| modified beam search | 10.40 | 10.51 |
|
||||
|
||||
|
||||
### Aishell
|
||||
|
||||
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
|
||||
and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc].
|
||||
We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc],
|
||||
[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7],
|
||||
|
||||
#### Conformer CTC Model
|
||||
|
||||
@ -140,20 +177,6 @@ The best CER we currently have is:
|
||||
|-----|------|
|
||||
| CER | 4.26 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained conformer CTC model: [
|
||||
|
||||
#### Transducer Stateless Model
|
||||
|
||||
The best CER we currently have is:
|
||||
|
||||
| | test |
|
||||
|-----|------|
|
||||
| CER | 4.68 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
|
||||
|
||||
#### TDNN LSTM CTC Model
|
||||
|
||||
The CER for this model is:
|
||||
@ -164,6 +187,46 @@ The CER for this model is:
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
|
||||
|
||||
#### Transducer Stateless Model
|
||||
|
||||
The best CER we currently have is:
|
||||
|
||||
| | test |
|
||||
|-----|------|
|
||||
| CER | 4.38 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
|
||||
|
||||
|
||||
### Aishell2
|
||||
|
||||
We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5].
|
||||
|
||||
#### Transducer Stateless Model
|
||||
|
||||
The best WER we currently have is:
|
||||
|
||||
| | dev-ios | test-ios |
|
||||
|-----|------------|------------|
|
||||
| WER | 5.32 | 5.56 |
|
||||
|
||||
|
||||
### Aishell4
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
|
||||
|
||||
The best CER we currently have is:
|
||||
|
||||
| | test |
|
||||
|-----|------------|
|
||||
| CER | 29.08 |
|
||||
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
|
||||
|
||||
|
||||
### TIMIT
|
||||
|
||||
We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc]
|
||||
@ -187,7 +250,8 @@ The PER for this model is:
|
||||
|--|--|
|
||||
|PER| 17.66% |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [](https://colab.research.google.com/drive/11IT-k4HQIgQngXz1uvWsEYktjqQt7Tmb?usp=sharing)
|
||||
We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
|
||||
|
||||
|
||||
### TED-LIUM3
|
||||
|
||||
@ -215,24 +279,6 @@ The best WER using modified beam search with beam size 4 is:
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
|
||||
|
||||
### GigaSpeech
|
||||
|
||||
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
|
||||
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
|
||||
|
||||
#### Conformer CTC
|
||||
|
||||
| | Dev | Test |
|
||||
|-----|-------|-------|
|
||||
| WER | 10.47 | 10.58 |
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||
|
||||
| | Dev | Test |
|
||||
|----------------------|-------|-------|
|
||||
| greedy search | 10.51 | 10.73 |
|
||||
| fast beam search | 10.50 | 10.69 |
|
||||
| modified beam search | 10.40 | 10.51 |
|
||||
|
||||
### Aidatatang_200zh
|
||||
|
||||
@ -248,6 +294,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
|
||||
|
||||
|
||||
### WenetSpeech
|
||||
|
||||
We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
|
||||
@ -284,20 +331,6 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
|
||||
|
||||
### Aishell4
|
||||
|
||||
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
|
||||
|
||||
The best CER(%) results:
|
||||
| | test |
|
||||
|----------------------|--------|
|
||||
| greedy search | 29.89 |
|
||||
| fast beam search | 28.91 |
|
||||
| modified beam search | 29.08 |
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
|
||||
|
||||
### TAL_CSASR
|
||||
|
||||
@ -331,8 +364,12 @@ Please see: [. However,
|
||||
you can change the testing set to any other domains (e.g `GigaSpeech`_) and prepare the language models
|
||||
using that corpus.
|
||||
|
||||
First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here <https://arxiv.org/abs/2002.11268>`_
|
||||
to address the language information mismatch between the training
|
||||
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
|
||||
are acoustically similar, DR derives the following formular for decoding with Bayes' theorem:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{score}\left(y_u|\mathit{x},y\right) =
|
||||
\log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
|
||||
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
|
||||
\lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right)
|
||||
|
||||
|
||||
where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively.
|
||||
Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to
|
||||
shallow fusion is the subtraction of the source domain LM.
|
||||
|
||||
Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is
|
||||
considered to be weak and can only capture low-level language information. Therefore, `LODR <https://arxiv.org/abs/2203.16776>`__ proposed to use
|
||||
a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula
|
||||
during decoding for transducer model:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{score}\left(y_u|\mathit{x},y\right) =
|
||||
\log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) +
|
||||
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
|
||||
\lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)
|
||||
|
||||
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR,
|
||||
the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
|
||||
LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
|
||||
As a bi-gram is much faster to evaluate, LODR is usually much faster.
|
||||
|
||||
Now, we will show you how to use LODR in ``icefall``.
|
||||
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_.
|
||||
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
|
||||
The testing scenario here is intra-domain (we decode the model trained on `LibriSpeech`_ on `LibriSpeech`_ testing sets).
|
||||
|
||||
As the initial step, let's download the pre-trained model.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||
|
||||
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
|
||||
$ ./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir $exp_dir \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search
|
||||
|
||||
The following WERs are achieved on test-clean and test-other:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ For test-clean, WER of different settings are:
|
||||
$ beam_size_4 3.11 best for test-clean
|
||||
$ For test-other, WER of different settings are:
|
||||
$ beam_size_4 7.93 best for test-other
|
||||
|
||||
Then, we download the external language model and bi-gram LM that are necessary for LODR.
|
||||
Note that the bi-gram is estimated on the LibriSpeech 960 hours' text.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ # download the external LM
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
$ # create a symbolic link so that the checkpoint can be loaded
|
||||
$ pushd icefall-librispeech-rnn-lm/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt
|
||||
$ popd
|
||||
$
|
||||
$ # download the bi-gram
|
||||
$ git lfs install
|
||||
$ git clone https://huggingface.co/marcoyang/librispeech_bigram
|
||||
$ pushd data/lang_bpe_500
|
||||
$ ln -s ../../librispeech_bigram/2gram.fst.txt .
|
||||
$ popd
|
||||
|
||||
Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_lm_LODR``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ lm_dir=./icefall-librispeech-rnn-lm/exp
|
||||
$ lm_scale=0.42
|
||||
$ LODR_scale=-0.24
|
||||
$ ./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model False \
|
||||
--beam-size 4 \
|
||||
--exp-dir $exp_dir \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search_lm_LODR \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||
--use-shallow-fusion 1 \
|
||||
--lm-type rnn \
|
||||
--lm-exp-dir $lm_dir \
|
||||
--lm-epoch 99 \
|
||||
--lm-scale $lm_scale \
|
||||
--lm-avg 1 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--lm-vocab-size 500 \
|
||||
--tokens-ngram 2 \
|
||||
--ngram-lm-scale $LODR_scale
|
||||
|
||||
There are two extra arguments that need to be given when doing LODR. ``--tokens-ngram`` specifies the order of n-gram. As we
|
||||
are using a bi-gram, we set it to 2. ``--ngram-lm-scale`` is the scale of the bi-gram, it should be a negative number
|
||||
as we are subtracting the bi-gram's score during decoding.
|
||||
|
||||
The decoding results obtained with the above command are shown below:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ For test-clean, WER of different settings are:
|
||||
$ beam_size_4 2.61 best for test-clean
|
||||
$ For test-other, WER of different settings are:
|
||||
$ beam_size_4 6.74 best for test-other
|
||||
|
||||
Recall that the lowest WER we obtained in :ref:`shallow_fusion` with beam size of 4 is ``2.77/7.08``, LODR
|
||||
indeed **further improves** the WER. We can do even better if we increase ``--beam-size``:
|
||||
|
||||
.. list-table:: WER of LODR with different beam sizes
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Beam size
|
||||
- test-clean
|
||||
- test-other
|
||||
* - 4
|
||||
- 2.61
|
||||
- 6.74
|
||||
* - 8
|
||||
- 2.45
|
||||
- 6.38
|
||||
* - 12
|
||||
- 2.4
|
||||
- 6.23
|
12
docs/source/decoding-with-langugage-models/index.rst
Normal file
12
docs/source/decoding-with-langugage-models/index.rst
Normal file
@ -0,0 +1,12 @@
|
||||
Decoding with language models
|
||||
=============================
|
||||
|
||||
This section describes how to use external langugage models
|
||||
during decoding to improve the WER of transducer models.
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
shallow-fusion
|
||||
LODR
|
||||
rescoring
|
252
docs/source/decoding-with-langugage-models/rescoring.rst
Normal file
252
docs/source/decoding-with-langugage-models/rescoring.rst
Normal file
@ -0,0 +1,252 @@
|
||||
.. _rescoring:
|
||||
|
||||
LM rescoring for Transducer
|
||||
=================================
|
||||
|
||||
LM rescoring is a commonly used approach to incorporate external LM information. Unlike shallow-fusion-based
|
||||
methods (see :ref:`shallow-fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search.
|
||||
Rescoring is usually more efficient than shallow fusion since less computation is performed on the external LM.
|
||||
In this tutorial, we will show you how to use external LM to rescore the n-best hypotheses decoded from neural transducer models in
|
||||
`icefall <https://github.com/k2-fsa/icefall>`__.
|
||||
|
||||
.. note::
|
||||
|
||||
This tutorial is based on the recipe
|
||||
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
|
||||
which is a streaming transducer model trained on `LibriSpeech`_.
|
||||
However, you can easily apply shallow fusion to other recipes.
|
||||
If you encounter any problems, please open an issue `here <https://github.com/k2-fsa/icefall/issues>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
|
||||
to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain.
|
||||
|
||||
.. HINT::
|
||||
|
||||
We recommend you to use a GPU for decoding.
|
||||
|
||||
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
|
||||
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
|
||||
|
||||
As the initial step, let's download the pre-trained model.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||
|
||||
As usual, we first test the model's performance without external LM. This can be done via the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
|
||||
$ ./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir $exp_dir \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search
|
||||
|
||||
The following WERs are achieved on test-clean and test-other:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ For test-clean, WER of different settings are:
|
||||
$ beam_size_4 3.11 best for test-clean
|
||||
$ For test-other, WER of different settings are:
|
||||
$ beam_size_4 7.93 best for test-other
|
||||
|
||||
Now, we will try to improve the above WER numbers via external LM rescoring. We will download
|
||||
a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.
|
||||
|
||||
.. note::
|
||||
|
||||
This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus.
|
||||
You may also train a RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
|
||||
for training a RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ # download the external LM
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
$ # create a symbolic link so that the checkpoint can be loaded
|
||||
$ pushd icefall-librispeech-rnn-lm/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt
|
||||
$ popd
|
||||
|
||||
|
||||
With the RNNLM available, we can rescore the n-best hypotheses generated from `modified_beam_search`. Here,
|
||||
`n` should be the number of beams, i.e ``--beam-size``. The command for LM rescoring is
|
||||
as follows. Note that the ``--decoding-method`` is set to `modified_beam_search_lm_rescore` and ``--use-shallow-fusion``
|
||||
is set to `False`.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ lm_dir=./icefall-librispeech-rnn-lm/exp
|
||||
$ lm_scale=0.43
|
||||
$ ./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model False \
|
||||
--beam-size 4 \
|
||||
--exp-dir $exp_dir \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search_lm_rescore \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||
--use-shallow-fusion 0 \
|
||||
--lm-type rnn \
|
||||
--lm-exp-dir $lm_dir \
|
||||
--lm-epoch 99 \
|
||||
--lm-scale $lm_scale \
|
||||
--lm-avg 1 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--lm-vocab-size 500
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ For test-clean, WER of different settings are:
|
||||
$ beam_size_4 2.93 best for test-clean
|
||||
$ For test-other, WER of different settings are:
|
||||
$ beam_size_4 7.6 best for test-other
|
||||
|
||||
Great! We made some improvements! Increasing the size of the n-best hypotheses will further boost the performance,
|
||||
see the following table:
|
||||
|
||||
.. list-table:: WERs of LM rescoring with different beam sizes
|
||||
:widths: 25 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* - Beam size
|
||||
- test-clean
|
||||
- test-other
|
||||
* - 4
|
||||
- 2.93
|
||||
- 7.6
|
||||
* - 8
|
||||
- 2.67
|
||||
- 7.11
|
||||
* - 12
|
||||
- 2.59
|
||||
- 6.86
|
||||
|
||||
In fact, we can also apply LODR (see :ref:`LODR`) when doing LM rescoring. To do so, we need to
|
||||
download the bi-gram required by LODR:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ # download the bi-gram
|
||||
$ git lfs install
|
||||
$ git clone https://huggingface.co/marcoyang/librispeech_bigram
|
||||
$ pushd data/lang_bpe_500
|
||||
$ ln -s ../../librispeech_bigram/2gram.arpa .
|
||||
$ popd
|
||||
|
||||
Then we can performn LM rescoring + LODR by changing the decoding method to `modified_beam_search_lm_rescore_LODR`.
|
||||
|
||||
.. note::
|
||||
|
||||
This decoding method requires the dependency of `kenlm <https://github.com/kpu/kenlm>`_. You can install it
|
||||
via this command: `pip install https://github.com/kpu/kenlm/archive/master.zip`.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ lm_dir=./icefall-librispeech-rnn-lm/exp
|
||||
$ lm_scale=0.43
|
||||
$ ./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model False \
|
||||
--beam-size 4 \
|
||||
--exp-dir $exp_dir \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search_lm_rescore_LODR \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||
--use-shallow-fusion 0 \
|
||||
--lm-type rnn \
|
||||
--lm-exp-dir $lm_dir \
|
||||
--lm-epoch 99 \
|
||||
--lm-scale $lm_scale \
|
||||
--lm-avg 1 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--lm-vocab-size 500
|
||||
|
||||
You should see the following WERs after executing the commands above:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ For test-clean, WER of different settings are:
|
||||
$ beam_size_4 2.9 best for test-clean
|
||||
$ For test-other, WER of different settings are:
|
||||
$ beam_size_4 7.57 best for test-other
|
||||
|
||||
It's slightly better than LM rescoring. If we further increase the beam size, we will see
|
||||
further improvements from LM rescoring + LODR:
|
||||
|
||||
.. list-table:: WERs of LM rescoring + LODR with different beam sizes
|
||||
:widths: 25 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* - Beam size
|
||||
- test-clean
|
||||
- test-other
|
||||
* - 4
|
||||
- 2.9
|
||||
- 7.57
|
||||
* - 8
|
||||
- 2.63
|
||||
- 7.04
|
||||
* - 12
|
||||
- 2.52
|
||||
- 6.73
|
||||
|
||||
As mentioned earlier, LM rescoring is usually faster than shallow-fusion based methods.
|
||||
Here, we benchmark the WERs and decoding speed of them:
|
||||
|
||||
.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean)
|
||||
:widths: 25 25 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* - Decoding method
|
||||
- beam=4
|
||||
- beam=8
|
||||
- beam=12
|
||||
* - `modified_beam_search`
|
||||
- 3.11/7.93; 132s
|
||||
- 3.1/7.95; 177s
|
||||
- 3.1/7.96; 210s
|
||||
* - `modified_beam_search_lm_shallow_fusion`
|
||||
- 2.77/7.08; 262s
|
||||
- 2.62/6.65; 352s
|
||||
- 2.58/6.65; 488s
|
||||
* - LODR
|
||||
- 2.61/6.74; 400s
|
||||
- 2.45/6.38; 610s
|
||||
- 2.4/6.23; 870s
|
||||
* - `modified_beam_search_lm_rescore`
|
||||
- 2.93/7.6; 156s
|
||||
- 2.67/7.11; 203s
|
||||
- 2.59/6.86; 255s
|
||||
* - `modified_beam_search_lm_rescore_LODR`
|
||||
- 2.9/7.57; 160s
|
||||
- 2.63/7.04; 203s
|
||||
- 2.52/6.73; 263s
|
||||
|
||||
.. note::
|
||||
|
||||
Decoding is performed with a single 32G V100, we set ``--max-duration`` to 600.
|
||||
Decoding time here is only for reference and it may vary.
|
176
docs/source/decoding-with-langugage-models/shallow-fusion.rst
Normal file
176
docs/source/decoding-with-langugage-models/shallow-fusion.rst
Normal file
@ -0,0 +1,176 @@
|
||||
.. _shallow_fusion:
|
||||
|
||||
Shallow fusion for Transducer
|
||||
=================================
|
||||
|
||||
External language models (LM) are commonly used to improve WERs for E2E ASR models.
|
||||
This tutorial shows you how to perform ``shallow fusion`` with an external LM
|
||||
to improve the word-error-rate of a transducer model.
|
||||
|
||||
.. note::
|
||||
|
||||
This tutorial is based on the recipe
|
||||
`pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
|
||||
which is a streaming transducer model trained on `LibriSpeech`_.
|
||||
However, you can easily apply shallow fusion to other recipes.
|
||||
If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
For simplicity, the training and testing corpus in this tutorial is the same (`LibriSpeech`_). However, you can change the testing set
|
||||
to any other domains (e.g `GigaSpeech`_) and use an external LM trained on that domain.
|
||||
|
||||
.. HINT::
|
||||
|
||||
We recommend you to use a GPU for decoding.
|
||||
|
||||
For illustration purpose, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
|
||||
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
|
||||
|
||||
As the initial step, let's download the pre-trained model.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||
|
||||
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
|
||||
$ ./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model False \
|
||||
--exp-dir $exp_dir \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search
|
||||
|
||||
The following WERs are achieved on test-clean and test-other:
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ For test-clean, WER of different settings are:
|
||||
$ beam_size_4 3.11 best for test-clean
|
||||
$ For test-other, WER of different settings are:
|
||||
$ beam_size_4 7.93 best for test-other
|
||||
|
||||
These are already good numbers! But we can further improve it by using shallow fusion with external LM.
|
||||
Training a language model usually takes a long time, we can download a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ # download the external LM
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
|
||||
$ # create a symbolic link so that the checkpoint can be loaded
|
||||
$ pushd icefall-librispeech-rnn-lm/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt
|
||||
$ popd
|
||||
|
||||
.. note::
|
||||
|
||||
This is an RNN LM trained on the LibriSpeech text corpus. So it might not be ideal for other corpus.
|
||||
You may also train a RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
|
||||
for training a RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.
|
||||
|
||||
To use shallow fusion for decoding, we can execute the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||
$ lm_dir=./icefall-librispeech-rnn-lm/exp
|
||||
$ lm_scale=0.29
|
||||
$ ./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--use-averaged-model False \
|
||||
--beam-size 4 \
|
||||
--exp-dir $exp_dir \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||
--use-shallow-fusion 1 \
|
||||
--lm-type rnn \
|
||||
--lm-exp-dir $lm_dir \
|
||||
--lm-epoch 99 \
|
||||
--lm-scale $lm_scale \
|
||||
--lm-avg 1 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--lm-vocab-size 500
|
||||
|
||||
Note that we set ``--decoding-method modified_beam_search_lm_shallow_fusion`` and ``--use-shallow-fusion True``
|
||||
to use shallow fusion. ``--lm-type`` specifies the type of neural LM we are going to use, you can either choose
|
||||
between ``rnn`` or ``transformer``. The following three arguments are associated with the rnn:
|
||||
|
||||
- ``--rnn-lm-embedding-dim``
|
||||
The embedding dimension of the RNN LM
|
||||
|
||||
- ``--rnn-lm-hidden-dim``
|
||||
The hidden dimension of the RNN LM
|
||||
|
||||
- ``--rnn-lm-num-layers``
|
||||
The number of RNN layers in the RNN LM.
|
||||
|
||||
|
||||
The decoding result obtained with the above command are shown below.
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
$ For test-clean, WER of different settings are:
|
||||
$ beam_size_4 2.77 best for test-clean
|
||||
$ For test-other, WER of different settings are:
|
||||
$ beam_size_4 7.08 best for test-other
|
||||
|
||||
The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
|
||||
A few parameters can be tuned to further boost the performance of shallow fusion:
|
||||
|
||||
- ``--lm-scale``
|
||||
|
||||
Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
|
||||
the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
|
||||
|
||||
- ``--beam-size``
|
||||
|
||||
The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.
|
||||
|
||||
Here, we also show how `--beam-size` effect the WER and decoding time:
|
||||
|
||||
.. list-table:: WERs and decoding time (on test-clean) of shallow fusion with different beam sizes
|
||||
:widths: 25 25 25 25
|
||||
:header-rows: 1
|
||||
|
||||
* - Beam size
|
||||
- test-clean
|
||||
- test-other
|
||||
- Decoding time on test-clean (s)
|
||||
* - 4
|
||||
- 2.77
|
||||
- 7.08
|
||||
- 262
|
||||
* - 8
|
||||
- 2.62
|
||||
- 6.65
|
||||
- 352
|
||||
* - 12
|
||||
- 2.58
|
||||
- 6.65
|
||||
- 488
|
||||
|
||||
As we see, a larger beam size during shallow fusion improves the WER, but is also slower.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -34,3 +34,8 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
|
||||
|
||||
contributing/index
|
||||
huggingface/index
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
decoding-with-langugage-models/index
|
@ -3,64 +3,91 @@
|
||||
Installation
|
||||
============
|
||||
|
||||
- |os|
|
||||
- |device|
|
||||
- |python_versions|
|
||||
- |torch_versions|
|
||||
- |k2_versions|
|
||||
|
||||
.. |os| image:: ./images/os-Linux_macOS-ff69b4.svg
|
||||
:alt: Supported operating systems
|
||||
|
||||
.. |device| image:: ./images/device-CPU_CUDA-orange.svg
|
||||
:alt: Supported devices
|
||||
|
||||
.. |python_versions| image:: ./images/python-gt-v3.6-blue.svg
|
||||
:alt: Supported python versions
|
||||
|
||||
.. |torch_versions| image:: ./images/torch-gt-v1.6.0-green.svg
|
||||
:alt: Supported PyTorch versions
|
||||
|
||||
.. |k2_versions| image:: ./images/k2-gt-v1.9-blueviolet.svg
|
||||
:alt: Supported k2 versions
|
||||
|
||||
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
|
||||
`lhotse <https://github.com/lhotse-speech/lhotse>`_.
|
||||
|
||||
We recommend you to use the following steps to install the dependencies.
|
||||
We recommend that you use the following steps to install the dependencies.
|
||||
|
||||
- (0) Install PyTorch and torchaudio
|
||||
- (1) Install k2
|
||||
- (2) Install lhotse
|
||||
- (0) Install CUDA toolkit and cuDNN
|
||||
- (1) Install PyTorch and torchaudio
|
||||
- (2) Install k2
|
||||
- (3) Install lhotse
|
||||
|
||||
.. caution::
|
||||
|
||||
99% users who have issues about the installation are using conda.
|
||||
|
||||
.. caution::
|
||||
|
||||
99% users who have issues about the installation are using conda.
|
||||
|
||||
.. caution::
|
||||
|
||||
99% users who have issues about the installation are using conda.
|
||||
|
||||
.. hint::
|
||||
|
||||
We suggest that you use ``pip install`` to install PyTorch.
|
||||
|
||||
You can use the following command to create a virutal environment in Python:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python3 -m venv ./my_env
|
||||
source ./my_env/bin/activate
|
||||
|
||||
.. caution::
|
||||
|
||||
Installation order matters.
|
||||
|
||||
(0) Install PyTorch and torchaudio
|
||||
(0) Install CUDA toolkit and cuDNN
|
||||
----------------------------------
|
||||
|
||||
Please refer to
|
||||
`<https://k2-fsa.github.io/k2/installation/cuda-cudnn.html>`_
|
||||
to install CUDA and cuDNN.
|
||||
|
||||
|
||||
(1) Install PyTorch and torchaudio
|
||||
----------------------------------
|
||||
|
||||
Please refer `<https://pytorch.org/>`_ to install PyTorch
|
||||
and torchaudio.
|
||||
|
||||
.. hint::
|
||||
|
||||
(1) Install k2
|
||||
You can also go to `<https://download.pytorch.org/whl/torch_stable.html>`_
|
||||
to download pre-compiled wheels and install them.
|
||||
|
||||
.. caution::
|
||||
|
||||
Please install torch and torchaudio at the same time.
|
||||
|
||||
|
||||
(2) Install k2
|
||||
--------------
|
||||
|
||||
Please refer to `<https://k2-fsa.github.io/k2/installation/index.html>`_
|
||||
to install ``k2``.
|
||||
|
||||
.. CAUTION::
|
||||
.. caution::
|
||||
|
||||
You need to install ``k2`` with a version at least **v1.9**.
|
||||
Please don't change your installed PyTorch after you have installed k2.
|
||||
|
||||
.. HINT::
|
||||
.. note::
|
||||
|
||||
If you have already installed PyTorch and don't want to replace it,
|
||||
please install a version of ``k2`` that is compiled against the version
|
||||
of PyTorch you are using.
|
||||
We suggest that you install k2 from source by following
|
||||
`<https://k2-fsa.github.io/k2/installation/from_source.html>`_
|
||||
or
|
||||
`<https://k2-fsa.github.io/k2/installation/for_developers.html>`_.
|
||||
|
||||
(2) Install lhotse
|
||||
.. hint::
|
||||
|
||||
Please always install the latest version of k2.
|
||||
|
||||
(3) Install lhotse
|
||||
------------------
|
||||
|
||||
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
|
||||
@ -75,8 +102,7 @@ to install ``lhotse``.
|
||||
|
||||
to install the latest version of lhotse.
|
||||
|
||||
|
||||
(3) Download icefall
|
||||
(4) Download icefall
|
||||
--------------------
|
||||
|
||||
``icefall`` is a collection of Python scripts; what you need is to download it
|
||||
@ -338,44 +364,42 @@ The log of running ``./prepare.sh`` is:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2021-08-23 19:27:26 (prepare.sh:24:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
|
||||
2021-08-23 19:27:26 (prepare.sh:27:main) stage 0: Download data
|
||||
Downloading waves_yesno.tar.gz: 4.49MB [00:03, 1.39MB/s]
|
||||
2021-08-23 19:27:30 (prepare.sh:36:main) Stage 1: Prepare yesno manifest
|
||||
2021-08-23 19:27:31 (prepare.sh:42:main) Stage 2: Compute fbank for yesno
|
||||
2021-08-23 19:27:32,803 INFO [compute_fbank_yesno.py:52] Processing train
|
||||
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:01<00:00, 80.57it/s]
|
||||
2021-08-23 19:27:34,085 INFO [compute_fbank_yesno.py:52] Processing test
|
||||
Extracting and storing features: 100%|______________________________________________________________| 30/30 [00:00<00:00, 248.21it/s]
|
||||
2021-08-23 19:27:34 (prepare.sh:48:main) Stage 3: Prepare lang
|
||||
2021-08-23 19:27:35 (prepare.sh:63:main) Stage 4: Prepare G
|
||||
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
|
||||
d(std::istream&):79
|
||||
2023-05-12 17:55:21 (prepare.sh:27:main) dl_dir: /tmp/icefall/egs/yesno/ASR/download
|
||||
2023-05-12 17:55:21 (prepare.sh:30:main) Stage 0: Download data
|
||||
/tmp/icefall/egs/yesno/ASR/download/waves_yesno.tar.gz: 100%|_______________________________________________________________| 4.70M/4.70M [06:54<00:00, 11.4kB/s]
|
||||
2023-05-12 18:02:19 (prepare.sh:39:main) Stage 1: Prepare yesno manifest
|
||||
2023-05-12 18:02:21 (prepare.sh:45:main) Stage 2: Compute fbank for yesno
|
||||
2023-05-12 18:02:23,199 INFO [compute_fbank_yesno.py:65] Processing train
|
||||
Extracting and storing features: 100%|_______________________________________________________________| 90/90 [00:00<00:00, 212.60it/s]
|
||||
2023-05-12 18:02:23,640 INFO [compute_fbank_yesno.py:65] Processing test
|
||||
Extracting and storing features: 100%|_______________________________________________________________| 30/30 [00:00<00:00, 304.53it/s]
|
||||
2023-05-12 18:02:24 (prepare.sh:51:main) Stage 3: Prepare lang
|
||||
2023-05-12 18:02:26 (prepare.sh:66:main) Stage 4: Prepare G
|
||||
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):79
|
||||
[I] Reading \data\ section.
|
||||
/tmp/pip-install-fcordre9/kaldilm_6899d26f2d684ad48f21025950cd2866/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Rea
|
||||
d(std::istream&):140
|
||||
/project/kaldilm/csrc/arpa_file_parser.cc:void kaldilm::ArpaFileParser::Read(std::istream&):140
|
||||
[I] Reading \1-grams: section.
|
||||
2021-08-23 19:27:35 (prepare.sh:89:main) Stage 5: Compile HLG
|
||||
2021-08-23 19:27:35,928 INFO [compile_hlg.py:120] Processing data/lang_phone
|
||||
2021-08-23 19:27:35,929 INFO [lexicon.py:116] Converting L.pt to Linv.pt
|
||||
2021-08-23 19:27:35,931 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
|
||||
2021-08-23 19:27:35,932 INFO [compile_hlg.py:52] Loading G.fst.txt
|
||||
2021-08-23 19:27:35,932 INFO [compile_hlg.py:62] Intersecting L and G
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:64] LG shape: (4, None)
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:66] Connecting LG
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
|
||||
2021-08-23 19:27:35,933 INFO [compile_hlg.py:71] Determinizing LG
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:74] <class '_k2.RaggedInt'>
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
|
||||
2021-08-23 19:27:35,934 INFO [compile_hlg.py:87] LG shape after k2.remove_epsilon: (6, None)
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:92] Arc sorting LG
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:95] Composing H and LG
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:102] Connecting LG
|
||||
2021-08-23 19:27:35,935 INFO [compile_hlg.py:105] Arc sorting LG
|
||||
2021-08-23 19:27:35,936 INFO [compile_hlg.py:107] HLG.shape: (8, None)
|
||||
2021-08-23 19:27:35,936 INFO [compile_hlg.py:123] Saving HLG.pt to data/lang_phone
|
||||
2023-05-12 18:02:26 (prepare.sh:92:main) Stage 5: Compile HLG
|
||||
2023-05-12 18:02:28,581 INFO [compile_hlg.py:124] Processing data/lang_phone
|
||||
2023-05-12 18:02:28,582 INFO [lexicon.py:171] Converting L.pt to Linv.pt
|
||||
2023-05-12 18:02:28,609 INFO [compile_hlg.py:48] Building ctc_topo. max_token_id: 3
|
||||
2023-05-12 18:02:28,610 INFO [compile_hlg.py:52] Loading G.fst.txt
|
||||
2023-05-12 18:02:28,611 INFO [compile_hlg.py:62] Intersecting L and G
|
||||
2023-05-12 18:02:28,613 INFO [compile_hlg.py:64] LG shape: (4, None)
|
||||
2023-05-12 18:02:28,613 INFO [compile_hlg.py:66] Connecting LG
|
||||
2023-05-12 18:02:28,614 INFO [compile_hlg.py:68] LG shape after k2.connect: (4, None)
|
||||
2023-05-12 18:02:28,614 INFO [compile_hlg.py:70] <class 'torch.Tensor'>
|
||||
2023-05-12 18:02:28,614 INFO [compile_hlg.py:71] Determinizing LG
|
||||
2023-05-12 18:02:28,615 INFO [compile_hlg.py:74] <class '_k2.ragged.RaggedTensor'>
|
||||
2023-05-12 18:02:28,615 INFO [compile_hlg.py:76] Connecting LG after k2.determinize
|
||||
2023-05-12 18:02:28,615 INFO [compile_hlg.py:79] Removing disambiguation symbols on LG
|
||||
2023-05-12 18:02:28,616 INFO [compile_hlg.py:91] LG shape after k2.remove_epsilon: (6, None)
|
||||
2023-05-12 18:02:28,617 INFO [compile_hlg.py:96] Arc sorting LG
|
||||
2023-05-12 18:02:28,617 INFO [compile_hlg.py:99] Composing H and LG
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:106] Connecting LG
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:109] Arc sorting LG
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:111] HLG.shape: (8, None)
|
||||
2023-05-12 18:02:28,619 INFO [compile_hlg.py:127] Saving HLG.pt to data/lang_phone
|
||||
|
||||
|
||||
Training
|
||||
@ -408,49 +432,53 @@ The training log is given below:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2021-08-23 19:30:31,072 INFO [train.py:465] Training started
|
||||
2021-08-23 19:30:31,072 INFO [train.py:466] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01,
|
||||
'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, '
|
||||
best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_doub
|
||||
le_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'feature_dir': PosixPath('data/fbank'
|
||||
), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0
|
||||
, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
|
||||
2021-08-23 19:30:31,074 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:146] About to get train cuts
|
||||
2021-08-23 19:30:31,098 INFO [asr_datamodule.py:240] About to get train cuts
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:149] About to create train dataset
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:200] Using SingleCutSampler.
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:206] About to create train dataloader
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:219] About to get test cuts
|
||||
2021-08-23 19:30:31,102 INFO [asr_datamodule.py:246] About to get test cuts
|
||||
2021-08-23 19:30:31,357 INFO [train.py:416] Epoch 0, batch 0, batch avg loss 1.0789, total avg loss: 1.0789, batch size: 4
|
||||
2021-08-23 19:30:31,848 INFO [train.py:416] Epoch 0, batch 10, batch avg loss 0.5356, total avg loss: 0.7556, batch size: 4
|
||||
2021-08-23 19:30:32,301 INFO [train.py:432] Epoch 0, valid loss 0.9972, best valid loss: 0.9972 best valid epoch: 0
|
||||
2021-08-23 19:30:32,805 INFO [train.py:416] Epoch 0, batch 20, batch avg loss 0.2436, total avg loss: 0.5717, batch size: 3
|
||||
2021-08-23 19:30:33,109 INFO [train.py:432] Epoch 0, valid loss 0.4167, best valid loss: 0.4167 best valid epoch: 0
|
||||
2021-08-23 19:30:33,121 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-0.pt
|
||||
2021-08-23 19:30:33,325 INFO [train.py:416] Epoch 1, batch 0, batch avg loss 0.2214, total avg loss: 0.2214, batch size: 5
|
||||
2021-08-23 19:30:33,798 INFO [train.py:416] Epoch 1, batch 10, batch avg loss 0.0781, total avg loss: 0.1343, batch size: 5
|
||||
2021-08-23 19:30:34,065 INFO [train.py:432] Epoch 1, valid loss 0.0859, best valid loss: 0.0859 best valid epoch: 1
|
||||
2021-08-23 19:30:34,556 INFO [train.py:416] Epoch 1, batch 20, batch avg loss 0.0421, total avg loss: 0.0975, batch size: 3
|
||||
2021-08-23 19:30:34,810 INFO [train.py:432] Epoch 1, valid loss 0.0431, best valid loss: 0.0431 best valid epoch: 1
|
||||
2021-08-23 19:30:34,824 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-1.pt
|
||||
2023-05-12 18:04:59,759 INFO [train.py:481] Training started
|
||||
2023-05-12 18:04:59,759 INFO [train.py:482] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0,
|
||||
'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10,
|
||||
'reduction': 'sum', 'use_double_scores': True, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 15, 'seed': 42, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0,
|
||||
'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2,
|
||||
'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
|
||||
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
|
||||
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
|
||||
'k2-path': 'tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
|
||||
'lhotse-path': 'tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
|
||||
2023-05-12 18:04:59,761 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-05-12 18:04:59,764 INFO [train.py:495] device: cpu
|
||||
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:146] About to get train cuts
|
||||
2023-05-12 18:04:59,791 INFO [asr_datamodule.py:244] About to get train cuts
|
||||
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:149] About to create train dataset
|
||||
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:199] Using SingleCutSampler.
|
||||
2023-05-12 18:04:59,852 INFO [asr_datamodule.py:205] About to create train dataloader
|
||||
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:218] About to get test cuts
|
||||
2023-05-12 18:04:59,853 INFO [asr_datamodule.py:252] About to get test cuts
|
||||
2023-05-12 18:04:59,986 INFO [train.py:422] Epoch 0, batch 0, loss[loss=1.065, over 2436.00 frames. ], tot_loss[loss=1.065, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:00,352 INFO [train.py:422] Epoch 0, batch 10, loss[loss=0.4561, over 2828.00 frames. ], tot_loss[loss=0.7076, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:00,691 INFO [train.py:444] Epoch 0, validation loss=0.9002, over 18067.00 frames.
|
||||
2023-05-12 18:05:00,996 INFO [train.py:422] Epoch 0, batch 20, loss[loss=0.2555, over 2695.00 frames. ], tot_loss[loss=0.484, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:01,217 INFO [train.py:444] Epoch 0, validation loss=0.4688, over 18067.00 frames.
|
||||
2023-05-12 18:05:01,251 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-0.pt
|
||||
2023-05-12 18:05:01,389 INFO [train.py:422] Epoch 1, batch 0, loss[loss=0.2532, over 2436.00 frames. ], tot_loss[loss=0.2532, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:01,637 INFO [train.py:422] Epoch 1, batch 10, loss[loss=0.1139, over 2828.00 frames. ], tot_loss[loss=0.1592, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:01,859 INFO [train.py:444] Epoch 1, validation loss=0.1629, over 18067.00 frames.
|
||||
2023-05-12 18:05:02,094 INFO [train.py:422] Epoch 1, batch 20, loss[loss=0.0767, over 2695.00 frames. ], tot_loss[loss=0.118, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:02,350 INFO [train.py:444] Epoch 1, validation loss=0.06778, over 18067.00 frames.
|
||||
2023-05-12 18:05:02,395 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-1.pt
|
||||
|
||||
... ...
|
||||
|
||||
2021-08-23 19:30:49,657 INFO [train.py:416] Epoch 13, batch 0, batch avg loss 0.0109, total avg loss: 0.0109, batch size: 5
|
||||
2021-08-23 19:30:49,984 INFO [train.py:416] Epoch 13, batch 10, batch avg loss 0.0093, total avg loss: 0.0096, batch size: 4
|
||||
2021-08-23 19:30:50,239 INFO [train.py:432] Epoch 13, valid loss 0.0104, best valid loss: 0.0101 best valid epoch: 12
|
||||
2021-08-23 19:30:50,569 INFO [train.py:416] Epoch 13, batch 20, batch avg loss 0.0092, total avg loss: 0.0096, batch size: 2
|
||||
2021-08-23 19:30:50,819 INFO [train.py:432] Epoch 13, valid loss 0.0101, best valid loss: 0.0101 best valid epoch: 13
|
||||
2021-08-23 19:30:50,835 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-13.pt
|
||||
2021-08-23 19:30:51,024 INFO [train.py:416] Epoch 14, batch 0, batch avg loss 0.0105, total avg loss: 0.0105, batch size: 5
|
||||
2021-08-23 19:30:51,317 INFO [train.py:416] Epoch 14, batch 10, batch avg loss 0.0099, total avg loss: 0.0097, batch size: 4
|
||||
2021-08-23 19:30:51,552 INFO [train.py:432] Epoch 14, valid loss 0.0108, best valid loss: 0.0101 best valid epoch: 13
|
||||
2021-08-23 19:30:51,869 INFO [train.py:416] Epoch 14, batch 20, batch avg loss 0.0096, total avg loss: 0.0097, batch size: 5
|
||||
2021-08-23 19:30:52,107 INFO [train.py:432] Epoch 14, valid loss 0.0102, best valid loss: 0.0101 best valid epoch: 13
|
||||
2021-08-23 19:30:52,126 INFO [checkpoint.py:62] Saving checkpoint to tdnn/exp/epoch-14.pt
|
||||
2021-08-23 19:30:52,128 INFO [train.py:537] Done!
|
||||
2023-05-12 18:05:14,789 INFO [train.py:422] Epoch 13, batch 0, loss[loss=0.01056, over 2436.00 frames. ], tot_loss[loss=0.01056, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:15,016 INFO [train.py:422] Epoch 13, batch 10, loss[loss=0.009022, over 2828.00 frames. ], tot_loss[loss=0.009985, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:15,271 INFO [train.py:444] Epoch 13, validation loss=0.01088, over 18067.00 frames.
|
||||
2023-05-12 18:05:15,497 INFO [train.py:422] Epoch 13, batch 20, loss[loss=0.01174, over 2695.00 frames. ], tot_loss[loss=0.01077, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:15,747 INFO [train.py:444] Epoch 13, validation loss=0.01087, over 18067.00 frames.
|
||||
2023-05-12 18:05:15,783 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-13.pt
|
||||
2023-05-12 18:05:15,921 INFO [train.py:422] Epoch 14, batch 0, loss[loss=0.01045, over 2436.00 frames. ], tot_loss[loss=0.01045, over 2436.00 frames. ], batch size: 4
|
||||
2023-05-12 18:05:16,146 INFO [train.py:422] Epoch 14, batch 10, loss[loss=0.008957, over 2828.00 frames. ], tot_loss[loss=0.009903, over 22192.90 frames. ], batch size: 4
|
||||
2023-05-12 18:05:16,374 INFO [train.py:444] Epoch 14, validation loss=0.01092, over 18067.00 frames.
|
||||
2023-05-12 18:05:16,598 INFO [train.py:422] Epoch 14, batch 20, loss[loss=0.01169, over 2695.00 frames. ], tot_loss[loss=0.01065, over 34971.47 frames. ], batch size: 5
|
||||
2023-05-12 18:05:16,824 INFO [train.py:444] Epoch 14, validation loss=0.01077, over 18067.00 frames.
|
||||
2023-05-12 18:05:16,862 INFO [checkpoint.py:75] Saving checkpoint to tdnn/exp/epoch-14.pt
|
||||
2023-05-12 18:05:16,865 INFO [train.py:555] Done!
|
||||
|
||||
Decoding
|
||||
~~~~~~~~
|
||||
@ -465,22 +493,25 @@ The decoding log is:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2021-08-23 19:35:30,192 INFO [decode.py:249] Decoding started
|
||||
2021-08-23 19:35:30,192 INFO [decode.py:250] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2}
|
||||
2021-08-23 19:35:30,193 INFO [lexicon.py:113] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2021-08-23 19:35:30,213 INFO [decode.py:259] device: cpu
|
||||
2021-08-23 19:35:30,217 INFO [decode.py:279] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||
/tmp/icefall/icefall/checkpoint.py:146: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch.
|
||||
It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
|
||||
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:450.)
|
||||
avg[k] //= n
|
||||
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:219] About to get test cuts
|
||||
2021-08-23 19:35:30,220 INFO [asr_datamodule.py:246] About to get test cuts
|
||||
2021-08-23 19:35:30,409 INFO [decode.py:190] batch 0/8, cuts processed until now is 4
|
||||
2021-08-23 19:35:30,571 INFO [decode.py:228] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
||||
2021-08-23 19:35:30,572 INFO [utils.py:317] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
2021-08-23 19:35:30,573 INFO [decode.py:236] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
||||
2021-08-23 19:35:30,573 INFO [decode.py:299] Done!
|
||||
2023-05-12 18:08:30,482 INFO [decode.py:263] Decoding started
|
||||
2023-05-12 18:08:30,483 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23,
|
||||
'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'),
|
||||
'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True,
|
||||
'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3b7f09fa35e72589914f67089c0da9f196a92ca4', 'k2-git-date': 'Mon May 8 22:58:45 2023',
|
||||
'lhotse-version': '1.15.0.dev+git.6fcfced.clean', 'torch-version': '2.0.0+cu118', 'torch-cuda-available': False, 'torch-cuda-version': '11.8', 'python-version': '3.1', 'icefall-git-branch': 'master',
|
||||
'icefall-git-sha1': '30bde4b-clean', 'icefall-git-date': 'Thu May 11 17:37:47 2023', 'icefall-path': '/tmp/icefall',
|
||||
'k2-path': '/tmp/lib/python3.10/site-packages/k2-1.24.3.dev20230512+cuda11.8.torch2.0.0-py3.10-linux-x86_64.egg/k2/__init__.py',
|
||||
'lhotse-path': '/tmp/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'host', 'IP address': '0.0.0.0'}}
|
||||
2023-05-12 18:08:30,483 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||
2023-05-12 18:08:30,487 INFO [decode.py:273] device: cpu
|
||||
2023-05-12 18:08:30,513 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:218] About to get test cuts
|
||||
2023-05-12 18:08:30,521 INFO [asr_datamodule.py:252] About to get test cuts
|
||||
2023-05-12 18:08:30,675 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
|
||||
2023-05-12 18:08:30,923 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
||||
2023-05-12 18:08:30,924 INFO [utils.py:558] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||
2023-05-12 18:08:30,925 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
||||
2023-05-12 18:08:30,925 INFO [decode.py:316] Done!
|
||||
|
||||
**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.
|
||||
|
||||
|
@ -276,7 +276,7 @@ The result looks like below:
|
||||
|
||||
7767517
|
||||
2029 2547
|
||||
SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31
|
||||
SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 15=1 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31
|
||||
Input in0 0 1 in0
|
||||
|
||||
**Explanation**
|
||||
@ -300,6 +300,9 @@ The result looks like below:
|
||||
- ``3=7``, 3 is the key and 7 is the value of for the amount of padding
|
||||
used in the Conv2DSubsampling layer. It should be 7 for zipformer
|
||||
if you don't change zipformer.py.
|
||||
- ``15=1``, attribute 15, this is the model version. Starting from
|
||||
`sherpa-ncnn`_ v2.0, we require that the model version has to
|
||||
be >= 1.
|
||||
- ``-23316=5,2,4,3,2,4``, attribute 16, this is an array attribute.
|
||||
It is attribute 16 since -23300 - (-23316) = 16.
|
||||
The first element of the array is the length of the array, which is 5 in our case.
|
||||
@ -338,6 +341,8 @@ The result looks like below:
|
||||
+----------+--------------------------------------------+
|
||||
| 3 | 7 (if you don't change code) |
|
||||
+----------+--------------------------------------------+
|
||||
| 15 | 1 (The model version) |
|
||||
+----------+--------------------------------------------+
|
||||
|-23316 | ``--num-encoder-layer`` |
|
||||
+----------+--------------------------------------------+
|
||||
|-23317 | ``--encoder-dims`` |
|
||||
|
@ -3,6 +3,15 @@ Export to ONNX
|
||||
|
||||
In this section, we describe how to export models to `ONNX`_.
|
||||
|
||||
.. hint::
|
||||
|
||||
Before you continue, please run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install onnx
|
||||
|
||||
|
||||
In each recipe, there is a file called ``export-onnx.py``, which is used
|
||||
to export trained models to `ONNX`_.
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
Distillation with HuBERT
|
||||
========================
|
||||
|
||||
This tutorial shows you how to perform knowledge distillation in `icefall`_
|
||||
This tutorial shows you how to perform knowledge distillation in `icefall <https://github.com/k2-fsa/icefall>`_
|
||||
with the `LibriSpeech`_ dataset. The distillation method
|
||||
used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD).
|
||||
Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation <https://arxiv.org/abs/2211.00508>`_
|
||||
@ -13,7 +13,7 @@ for more details about MVQ-KD.
|
||||
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_.
|
||||
Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes
|
||||
with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you
|
||||
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
|
||||
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.
|
||||
|
||||
.. note::
|
||||
|
||||
@ -217,7 +217,7 @@ the following command.
|
||||
--exp-dir $exp_dir \
|
||||
--enable-distillation True
|
||||
|
||||
You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`_.
|
||||
You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`__.
|
||||
|
||||
That's all! Feel free to experiment with your own setups and report your results.
|
||||
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`_.
|
||||
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`__.
|
||||
|
@ -8,10 +8,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
|
||||
|
||||
.. Note::
|
||||
|
||||
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
|
||||
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
|
||||
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
|
||||
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
|
||||
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
|
||||
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
|
||||
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
|
||||
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
|
||||
We will take pruned_transducer_stateless4 as an example in this tutorial.
|
||||
|
||||
.. HINT::
|
||||
@ -237,7 +237,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
|
||||
|
||||
.. NOTE::
|
||||
|
||||
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
|
||||
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
|
||||
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.
|
||||
|
||||
|
||||
@ -529,13 +529,13 @@ Download pretrained models
|
||||
If you don't want to train from scratch, you can download the pretrained models
|
||||
by visiting the following links:
|
||||
|
||||
- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`_
|
||||
- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`__
|
||||
|
||||
- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`_
|
||||
- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`__
|
||||
|
||||
- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`_
|
||||
- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`__
|
||||
|
||||
- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`_
|
||||
- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`__
|
||||
|
||||
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
|
||||
for the details of the above pretrained models
|
||||
|
@ -45,9 +45,9 @@ the input features.
|
||||
|
||||
We have three variants of Emformer models in ``icefall``.
|
||||
|
||||
- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`_.
|
||||
- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`__.
|
||||
- ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio,
|
||||
ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model.
|
||||
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`_.
|
||||
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`__.
|
||||
- ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that
|
||||
it uses a simplified memory bank. See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless2>`_.
|
||||
|
@ -6,10 +6,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
|
||||
|
||||
.. Note::
|
||||
|
||||
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
|
||||
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
|
||||
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
|
||||
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
|
||||
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
|
||||
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
|
||||
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
|
||||
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
|
||||
We will take pruned_transducer_stateless4 as an example in this tutorial.
|
||||
|
||||
.. HINT::
|
||||
@ -264,7 +264,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
|
||||
|
||||
.. NOTE::
|
||||
|
||||
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
|
||||
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
|
||||
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.
|
||||
|
||||
.. Note::
|
||||
|
||||
The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
|
||||
The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`__,
|
||||
|
||||
.. HINT::
|
||||
|
||||
@ -642,7 +642,7 @@ Download pretrained models
|
||||
If you don't want to train from scratch, you can download the pretrained models
|
||||
by visiting the following links:
|
||||
|
||||
- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_
|
||||
- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__
|
||||
|
||||
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
|
||||
for the details of the above pretrained models
|
||||
|
@ -17,6 +17,7 @@ The following table lists the differences among them.
|
||||
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
|
||||
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
|
||||
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
|
||||
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -2,6 +2,109 @@
|
||||
|
||||
### Aishell training result(Stateless Transducer)
|
||||
|
||||
#### Pruned transducer stateless 7
|
||||
|
||||
[./pruned_transducer_stateless7](./pruned_transducer_stateless7)
|
||||
|
||||
It's Zipformer with Pruned RNNT loss.
|
||||
|
||||
| | test | dev | comment |
|
||||
|------------------------|------|------|---------------------------------------|
|
||||
| greedy search | 5.02 | 4.61 | --epoch 42 --avg 6 --max-duration 600 |
|
||||
| modified beam search | 4.81 | 4.4 | --epoch 42 --avg 6 --max-duration 600 |
|
||||
| fast beam search | 4.91 | 4.52 | --epoch 42 --avg 6 --max-duration 600 |
|
||||
|
||||
Training command is:
|
||||
|
||||
```bash
|
||||
./prepare.sh
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1"
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
--world-size 2 \
|
||||
--num-epochs 50 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless7/exp \
|
||||
--context-size 1 \
|
||||
--max-duration 300
|
||||
```
|
||||
|
||||
**Caution**: It uses `--context-size=1`.
|
||||
|
||||
The tensorboard log is available at
|
||||
<https://tensorboard.dev/experiment/MHYo3ApfQxaCdYLr38cQOQ>
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
for m in greedy_search modified_beam_search fast_beam_search ; do
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 42 \
|
||||
--avg 6 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 300 \
|
||||
--context-size 1 \
|
||||
--decoding-method $m
|
||||
|
||||
done
|
||||
```
|
||||
|
||||
Pretrained models, training logs, decoding logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21>
|
||||
#### Pruned transducer stateless 7 (zipformer)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/986>
|
||||
|
||||
[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe)
|
||||
|
||||
**Note**: The modeling units are byte level BPEs
|
||||
|
||||
The best results I have gotten are:
|
||||
|
||||
Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments
|
||||
-- | -- | -- | -- | -- | --
|
||||
500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29
|
||||
|
||||
The training command:
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="4,5,6,7"
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 50 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--max-duration 800 \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--exp-dir pruned_transducer_stateless7_bbpe/exp \
|
||||
--lr-epochs 6 \
|
||||
--master-port 12535
|
||||
```
|
||||
|
||||
The decoding command:
|
||||
|
||||
```
|
||||
for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 48 \
|
||||
--avg 29 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-sym-per-frame 1 \
|
||||
--ngram-lm-scale 0.25 \
|
||||
--ilme-scale 0.2 \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--max-duration 2000 \
|
||||
--decoding-method $m
|
||||
done
|
||||
```
|
||||
|
||||
The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
|
||||
|
||||
|
||||
#### Pruned transducer stateless 3
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/436>
|
||||
|
1
egs/aishell/ASR/local/compile_lg.py
Symbolic link
1
egs/aishell/ASR/local/compile_lg.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`:
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
|
||||
return tokens
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain the bpe.model and words.txt
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
lang_dir = Path("data/lang_char")
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
text_file = lang_dir / "text"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
267
egs/aishell/ASR/local/prepare_lang_bbpe.py
Executable file
267
egs/aishell/ASR/local/prepare_lang_bbpe.py
Executable file
@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
This script takes as input `lang_dir`, which should contain::
|
||||
|
||||
- lang_dir/bbpe.model,
|
||||
- lang_dir/words.txt
|
||||
|
||||
and generates the following files in the directory `lang_dir`:
|
||||
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
from icefall.byte_utils import byte_encode
|
||||
from icefall.utils import str2bool, tokenize_by_CJK_char
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format).
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
loop_state = 0 # words enter and leave from here
|
||||
next_state = 1 # the next un-allocated state, will be incremented as we go
|
||||
|
||||
arcs = []
|
||||
|
||||
# The blank symbol <blk> is defined in local/train_bpe_model.py
|
||||
assert token2id["<blk>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
for word, pieces in lexicon:
|
||||
assert len(pieces) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [token2id[i] for i in pieces]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, pieces[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last piece of this word
|
||||
i = len(pieces) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, pieces[i], w, 0])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
model_file: str, words: List[str], oov: str
|
||||
) -> Tuple[Lexicon, Dict[str, int]]:
|
||||
"""Generate a lexicon from a BPE model.
|
||||
|
||||
Args:
|
||||
model_file:
|
||||
Path to a sentencepiece model.
|
||||
words:
|
||||
A list of strings representing words.
|
||||
oov:
|
||||
The out of vocabulary word in lexicon.
|
||||
Returns:
|
||||
Return a tuple with two elements:
|
||||
- A dict whose keys are words and values are the corresponding
|
||||
word pieces.
|
||||
- A dict representing the token symbol, mapping from tokens to IDs.
|
||||
"""
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(model_file))
|
||||
|
||||
# Convert word to word piece IDs instead of word piece strings
|
||||
# to avoid OOV tokens.
|
||||
encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
|
||||
words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)
|
||||
|
||||
# Now convert word piece IDs back to word piece strings.
|
||||
words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
|
||||
|
||||
lexicon = []
|
||||
for word, pieces in zip(words, words_pieces):
|
||||
lexicon.append((word, pieces))
|
||||
|
||||
lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
|
||||
|
||||
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
|
||||
|
||||
return lexicon, token2id
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain the bpe.model and words.txt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--oov",
|
||||
type=str,
|
||||
default="<UNK>",
|
||||
help="The out of vocabulary word in lexicon.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True for debugging, which will generate
|
||||
a visualization of the lexicon FST.
|
||||
|
||||
Caution: If your lexicon contains hundreds of thousands
|
||||
of lines, please set it to False!
|
||||
|
||||
See "test/test_bpe_lexicon.py" for usage.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
model_file = lang_dir / "bbpe.model"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
words = word_sym_table.symbols
|
||||
|
||||
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
|
||||
|
||||
for w in excluded:
|
||||
if w in words:
|
||||
words.remove(w)
|
||||
|
||||
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
next_token_id = max(token_sym_table.values()) + 1
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in token_sym_table
|
||||
token_sym_table[disambig] = next_token_id
|
||||
next_token_id += 1
|
||||
|
||||
word_sym_table.add("#0")
|
||||
word_sym_table.add("<s>")
|
||||
word_sym_table.add("</s>")
|
||||
|
||||
write_mapping(lang_dir / "tokens.txt", token_sym_table)
|
||||
|
||||
write_lexicon(lang_dir / "lexicon.txt", lexicon)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
lexicon_disambig,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
if args.debug:
|
||||
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
L.labels_sym = labels_sym
|
||||
L.aux_labels_sym = aux_labels_sym
|
||||
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
|
||||
|
||||
L_disambig.labels_sym = labels_sym
|
||||
L_disambig.aux_labels_sym = aux_labels_sym
|
||||
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
113
egs/aishell/ASR/local/train_bbpe_model.py
Executable file
113
egs/aishell/ASR/local/train_bbpe_model.py
Executable file
@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# You can install sentencepiece via:
|
||||
#
|
||||
# pip install sentencepiece
|
||||
#
|
||||
# Due to an issue reported in
|
||||
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
|
||||
#
|
||||
# Please install a version >=0.1.96
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
from icefall import byte_encode, tokenize_by_CJK_char
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
The generated bpe.model is saved to this directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--transcript",
|
||||
type=str,
|
||||
help="Training transcript.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
type=int,
|
||||
help="Vocabulary size for BPE training",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _convert_to_bchar(in_path: str, out_path: str):
|
||||
with open(out_path, "w") as f:
|
||||
for line in open(in_path, "r").readlines():
|
||||
f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n")
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
vocab_size = args.vocab_size
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
model_type = "unigram"
|
||||
|
||||
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
|
||||
character_coverage = 1.0
|
||||
input_sentence_size = 100000000
|
||||
|
||||
user_defined_symbols = ["<blk>", "<sos/eos>"]
|
||||
unk_id = len(user_defined_symbols)
|
||||
# Note: unk_id is fixed to 2.
|
||||
# If you change it, you should also change other
|
||||
# places that are using it.
|
||||
|
||||
temp = tempfile.NamedTemporaryFile()
|
||||
train_text = temp.name
|
||||
|
||||
_convert_to_bchar(args.transcript, train_text)
|
||||
|
||||
model_file = Path(model_prefix + ".model")
|
||||
if not model_file.is_file():
|
||||
spm.SentencePieceTrainer.train(
|
||||
input=train_text,
|
||||
vocab_size=vocab_size,
|
||||
model_type=model_type,
|
||||
model_prefix=model_prefix,
|
||||
input_sentence_size=input_sentence_size,
|
||||
character_coverage=character_coverage,
|
||||
user_defined_symbols=user_defined_symbols,
|
||||
unk_id=unk_id,
|
||||
bos_id=-1,
|
||||
eos_id=-1,
|
||||
)
|
||||
else:
|
||||
print(f"{model_file} exists - skipping")
|
||||
return
|
||||
|
||||
shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -35,6 +35,15 @@ dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# vocab size for sentence piece models.
|
||||
# It will generate data/lang_bbpe_xxx,
|
||||
# data/lang_bbpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
# 2000
|
||||
# 1000
|
||||
500
|
||||
)
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
@ -47,20 +56,6 @@ log() {
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "stage -1: Download LM"
|
||||
# We assume that you have installed the git-lfs, if not, you could install it
|
||||
# using: `sudo apt-get install git-lfs && git-lfs install`
|
||||
git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
|
||||
|
||||
if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
|
||||
git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
|
||||
pushd $dl_dir/lm
|
||||
git lfs pull --include "3-gram.unpruned.arpa"
|
||||
popd
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: Download data"
|
||||
|
||||
@ -134,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
|
||||
lang_phone_dir=data/lang_phone
|
||||
lang_char_dir=data/lang_char
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare phone based lang"
|
||||
mkdir -p $lang_phone_dir
|
||||
@ -183,45 +177,107 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
lang_char_dir=data/lang_char
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare char based lang"
|
||||
mkdir -p $lang_char_dir
|
||||
# We reuse words.txt from phone based lexicon
|
||||
# so that the two can share G.pt later.
|
||||
cp $lang_phone_dir/words.txt $lang_char_dir
|
||||
|
||||
# The transcripts in training set, generated in stage 5
|
||||
cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
|
||||
|
||||
cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
|
||||
cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text
|
||||
cut -d " " -f 2- > $lang_char_dir/text
|
||||
|
||||
(echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
|
||||
> $lang_char_dir/words.txt
|
||||
|
||||
cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
|
||||
| awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
|
||||
|
||||
num_lines=$(< $lang_char_dir/words.txt wc -l)
|
||||
(echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
|
||||
>> $lang_char_dir/words.txt
|
||||
|
||||
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
||||
./local/prepare_char.py
|
||||
./local/prepare_char.py --lang-dir $lang_char_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
log "Stage 7: Prepare Byte BPE based lang"
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bbpe_${vocab_size}
|
||||
mkdir -p $lang_dir
|
||||
|
||||
cp $lang_char_dir/words.txt $lang_dir
|
||||
cp $lang_char_dir/text $lang_dir
|
||||
|
||||
if [ ! -f $lang_dir/bbpe.model ]; then
|
||||
./local/train_bbpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript $lang_dir/text
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bbpe.py --lang-dir $lang_dir
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Prepare G"
|
||||
|
||||
mkdir -p data/lm
|
||||
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
|
||||
|
||||
# Train LM on transcripts
|
||||
if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
|
||||
python3 ./shared/make_kn_lm.py \
|
||||
-ngram-order 3 \
|
||||
-text $lang_char_dir/transcript_words.txt \
|
||||
-lm data/lm/3-gram.unpruned.arpa
|
||||
fi
|
||||
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
|
||||
# It is used in building HLG
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_phone_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
$dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
|
||||
fi
|
||||
fi
|
||||
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Compile HLG"
|
||||
./local/compile_hlg.py --lang-dir $lang_phone_dir
|
||||
./local/compile_hlg.py --lang-dir $lang_char_dir
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_char_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
log "Stage 9: Generate LM training data"
|
||||
log "Stage 9: Compile LG & HLG"
|
||||
./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
|
||||
./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bbpe_${vocab_size}
|
||||
./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
|
||||
done
|
||||
|
||||
./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
|
||||
./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bbpe_${vocab_size}
|
||||
./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Generate LM training data"
|
||||
|
||||
log "Processing char based data"
|
||||
out_dir=data/lm_training_char
|
||||
@ -267,8 +323,8 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Sort LM training data"
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "Stage 11: Sort LM training data"
|
||||
# Sort LM training data by sentence length in descending order
|
||||
# for ease of training.
|
||||
#
|
||||
@ -295,7 +351,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
--out-statistics $out_dir/statistics-test.txt
|
||||
fi
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Stage 11: Train RNN LM model"
|
||||
python ../../../icefall/rnn_lm/train.py \
|
||||
--start-epoch 0 \
|
||||
@ -306,9 +362,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
--hidden-dim 512 \
|
||||
--num-layers 2 \
|
||||
--batch-size 400 \
|
||||
--exp-dir rnnlm_char/exp_aishell1_small \
|
||||
--lm-data data/lm_char/sorted_lm_data_aishell1.pt \
|
||||
--lm-data-valid data/lm_char/sorted_lm_data_valid.pt \
|
||||
--exp-dir rnnlm_char/exp \
|
||||
--lm-data $out_dir/sorted_lm_data.pt \
|
||||
--lm-data-valid $out_dir/sorted_lm_data-valid.pt \
|
||||
--vocab-size 4336 \
|
||||
--master-port 12345
|
||||
fi
|
||||
|
1
egs/aishell/ASR/pruned_transducer_stateless7/aishell.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/aishell.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless3/aishell.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7/asr_datamodule.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless_modified-2/asr_datamodule.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7/beam_search.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless3/beam_search.py
|
685
egs/aishell/ASR/pruned_transducer_stateless7/decode.py
Executable file
685
egs/aishell/ASR/pruned_transducer_stateless7/decode.py
Executable file
@ -0,0 +1,685 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from aishell import AIShell
|
||||
from asr_datamodule import AsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless3/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
token_table: k2.SymbolTable,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
token_table:
|
||||
It maps token ID to a string.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
else:
|
||||
hyp_tokens = []
|
||||
batch_size = encoder_out.size(0)
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyp_tokens.append(hyp)
|
||||
|
||||
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
key = f"beam_size_{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
key += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
key += "-no-context-words"
|
||||
return {key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
token_table: k2.SymbolTable,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
token_table:
|
||||
It maps a token ID to a string.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
token_table=token_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
|
||||
store_transcripts(filename=recog_path, texts=results_char)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
|
||||
if os.path.exists(params.context_file):
|
||||
params.has_contexts = True
|
||||
else:
|
||||
params.has_contexts = False
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += "-no-contexts-words"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
if params.decoding_method == "modified_beam_search":
|
||||
if os.path.exists(params.context_file):
|
||||
contexts_text = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts_text.append(line.strip())
|
||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(contexts)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
context_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
asr_datamodule = AsrDataModule(args)
|
||||
aishell = AIShell(manifest_dir=args.manifest_dir)
|
||||
test_cuts = aishell.test_cuts()
|
||||
dev_cuts = aishell.valid_cuts()
|
||||
test_dl = asr_datamodule.test_dataloaders(test_cuts)
|
||||
dev_dl = asr_datamodule.test_dataloaders(dev_cuts)
|
||||
|
||||
test_sets = ["test", "dev"]
|
||||
test_dls = [test_dl, dev_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
token_table=lexicon.token_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7/decoder.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
|
87
egs/aishell/ASR/pruned_transducer_stateless7/decoder2.py
Normal file
87
egs/aishell/ASR/pruned_transducer_stateless7/decoder2.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size == 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
if torch.jit.is_tracing():
|
||||
# This is for exporting to PNNX via ONNX
|
||||
embedding_out = self.embedding(y)
|
||||
else:
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
|
589
egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
Executable file
589
egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
Executable file
@ -0,0 +1,589 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang
|
||||
# Xiaoyu Yang)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21/
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless7/export-onnx.py \
|
||||
--use-averaged-model 0 \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024"
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
- decoder-epoch-99-avg-1.onnx
|
||||
- joiner-epoch-99-avg-1.onnx
|
||||
|
||||
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||
use the exported ONNX models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder2 import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import Zipformer
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = value
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
A Zipformer encoder.
|
||||
encoder_proj:
|
||||
The projection layer for encoder from the joiner.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = encoder_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Please see the help information of Zipformer.forward
|
||||
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||
"""
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
|
||||
|
||||
encoder_out = self.encoder_proj(encoder_out)
|
||||
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
class OnnxDecoder(nn.Module):
|
||||
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.decoder_proj = decoder_proj
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, context_size).
|
||||
Returns
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
need_pad = False
|
||||
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||
decoder_output = decoder_output.squeeze(1)
|
||||
output = self.decoder_proj(decoder_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class OnnxJoiner(nn.Module):
|
||||
"""A wrapper for the joiner"""
|
||||
|
||||
def __init__(self, output_linear: nn.Linear):
|
||||
super().__init__()
|
||||
self.output_linear = output_linear
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
return logit
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
context_size = decoder_model.decoder.context_size
|
||||
vocab_size = decoder_model.decoder.vocab_size
|
||||
|
||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
y,
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"context_size": str(context_size),
|
||||
"vocab_size": str(vocab_size),
|
||||
}
|
||||
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
"""
|
||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||
logging.info(f"joiner dim: {joiner_dim}")
|
||||
|
||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
meta_data = {
|
||||
"joiner_dim": str(joiner_dim),
|
||||
}
|
||||
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
encoder_proj=model.joiner.encoder_proj,
|
||||
)
|
||||
|
||||
decoder = OnnxDecoder(
|
||||
decoder=model.decoder,
|
||||
decoder_proj=model.joiner.decoder_proj,
|
||||
)
|
||||
|
||||
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||
|
||||
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||
logging.info(f"total parameters: {total_num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
suffix = f"iter-{params.iter}"
|
||||
else:
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
suffix += f"-avg-{params.avg}"
|
||||
|
||||
opset_version = 13
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||
export_encoder_model_onnx(
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||
export_decoder_model_onnx(
|
||||
decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported decoder to {decoder_filename}")
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||
export_joiner_model_onnx(
|
||||
joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported joiner to {joiner_filename}")
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=encoder_filename,
|
||||
model_output=encoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=decoder_filename,
|
||||
model_output=decoder_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=joiner_filename,
|
||||
model_output=joiner_filename_int8,
|
||||
op_types_to_quantize=["MatMul"],
|
||||
weight_type=QuantType.QInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
321
egs/aishell/ASR/pruned_transducer_stateless7/export.py
Executable file
321
egs/aishell/ASR/pruned_transducer_stateless7/export.py
Executable file
@ -0,0 +1,321 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless7/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--lang-dir data/lang_char
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
|
||||
# You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
278
egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py
Executable file
278
egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py
Executable file
@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--lang-dir ./data/lang_char \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7/jit_pretrained.py \
|
||||
--nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model cpu_jit.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.nn_model_filename)
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
|
||||
lexicon = Lexicon(args.lang_dir)
|
||||
token_table = lexicon.token_table
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
hyps = [[token_table[t] for t in tokens] for tokens in hyps]
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7/joiner.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7/model.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/model.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7/onnx_check.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/onnx_check.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/onnx_check.py
|
419
egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py
Executable file
419
egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py
Executable file
@ -0,0 +1,419 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Download the pre-trained model
|
||||
|
||||
cd egs/librispeech/ASR
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
repo=$(basename $repo_url)
|
||||
|
||||
pushd $repo
|
||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||
git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
|
||||
|
||||
cd exp
|
||||
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
|
||||
popd
|
||||
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless3/export-onnx.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp/
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-9999-avg-1.onnx
|
||||
- decoder-epoch-9999-avg-1.onnx
|
||||
- joiner-epoch-9999-avg-1.onnx
|
||||
|
||||
3. Run this file
|
||||
|
||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
|
||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 2-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, its shape is (N, T', joiner_dim)
|
||||
- encoder_out_lens, its shape is (N,)
|
||||
"""
|
||||
out = self.encoder.run(
|
||||
[
|
||||
self.encoder.get_outputs()[0].name,
|
||||
self.encoder.get_outputs()[1].name,
|
||||
],
|
||||
{
|
||||
self.encoder.get_inputs()[0].name: x.numpy(),
|
||||
self.encoder.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, joiner_dim)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
# current_encoder_out's shape: (batch_size, joiner_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
logits = model.run_joiner(current_encoder_out, decoder_out)
|
||||
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||
text = ""
|
||||
for i in token_ids:
|
||||
text += symbol_table[i]
|
||||
return text.replace("▁", " ").strip()
|
||||
|
||||
context_size = model.context_size
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = token_ids_to_words(hyp[context_size:])
|
||||
s += f"{filename}:\n{words}\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7/optim.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
|
348
egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py
Normal file
348
egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py
Normal file
@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless7/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||
--lang-dir ./data/lang_char \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
token_table = lexicon.token_table
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp_tokens = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp_tokens = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7/scaling.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
|
1255
egs/aishell/ASR/pruned_transducer_stateless7/train.py
Executable file
1255
egs/aishell/ASR/pruned_transducer_stateless7/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1255
egs/aishell/ASR/pruned_transducer_stateless7/train2.py
Executable file
1255
egs/aishell/ASR/pruned_transducer_stateless7/train2.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/aishell/ASR/pruned_transducer_stateless7/zipformer.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/asr_datamodule.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
818
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
Executable file
818
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
Executable file
@ -0,0 +1,818 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import (
|
||||
LmScorer,
|
||||
NgramLm,
|
||||
byte_encode,
|
||||
smart_byte_decode,
|
||||
tokenize_by_CJK_char,
|
||||
)
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bbpe_500/bbpe.model",
|
||||
help="Path to the byte BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bbpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
If you use fast_beam_search_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ilme-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for the internal language model estimation.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "fast_beam_search_LG":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
ilme_scale=params.ilme_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
ref_texts = []
|
||||
for tx in supervisions["text"]:
|
||||
ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))
|
||||
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(ref_texts),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
key += f"_ilme_scale_{params.ilme_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural network LM, used during shallow fusion
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
params.suffix += f"-ilme-scale-{params.ilme_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bbpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
|
||||
test_cuts = aishell.test_cuts()
|
||||
dev_cuts = aishell.valid_cuts()
|
||||
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
dev_dl = aishell.test_dataloaders(dev_cuts)
|
||||
|
||||
test_sets = ["test", "dev"]
|
||||
test_dls = [test_dl, dev_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/encoder_interface.py
|
320
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py
Executable file
320
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py
Executable file
@ -0,0 +1,320 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/aishell/ASR
|
||||
./pruned_transducer_stateless7_bbpe/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
|
||||
# You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_bbpe/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bbpe_500/bbpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bbpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
274
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
Executable file
274
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
Executable file
@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 49 \
|
||||
--avg 28 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/jit_pretrained.py \
|
||||
--nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall import smart_byte_decode
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model cpu_jit.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.nn_model_filename)
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = smart_byte_decode(sp.decode(hyp))
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/model.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
|
345
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
Executable file
345
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
Executable file
@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7_bbpe/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_bbpe/exp \
|
||||
--bpe-model data/lang_bbpe_500/bbpe.model \
|
||||
--epoch 48 \
|
||||
--avg 29
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_bbpe/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless7_bbpe/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) fast beam search
|
||||
./pruned_transducer_stateless7_bbpe/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bbpe_500/bbpe.model \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7_bbpe/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import smart_byte_decode
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(smart_byte_decode(hyp).split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(smart_byte_decode(sp.decode(hyp)).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
|
1250
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
Executable file
1250
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
|
@ -21,7 +21,7 @@ import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import (
|
||||
@ -181,7 +181,16 @@ class AishellAsrDataModule:
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
|
||||
def train_dataloaders(
|
||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
|
||||
@ -277,6 +286,10 @@ class AishellAsrDataModule:
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
@ -325,7 +338,7 @@ class AishellAsrDataModule:
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
logging.info("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
|
@ -20,7 +20,7 @@ import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from lhotse import CutSet, Fbank, FbankConfig
|
||||
from lhotse.dataset import (
|
||||
@ -144,6 +144,7 @@ class AsrDataModule:
|
||||
cuts_train: CutSet,
|
||||
on_the_fly_feats: bool,
|
||||
cuts_musan: Optional[CutSet] = None,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
@ -228,6 +229,10 @@ class AsrDataModule:
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
logging.info("About to create train dataloader")
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user