Commit 5bbce704e2
Merge remote-tracking branch 'dan/master' into modified-conformer-with-multi-datasets
47  .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh (vendored, executable file)
@@ -0,0 +1,47 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./pruned_transducer_stateless/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./pruned_transducer_stateless/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done
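Note: the log() helper above resolves the caller's file, line, and function from bash's
BASH_SOURCE, BASH_LINENO, and FUNCNAME arrays. A minimal sketch of what a call from the
top level of this script emits (timestamp and line number are illustrative):

  log "decoding started"
  # 2022-03-12 10:15:02 (run-librispeech-pruned-transducer-stateless-2022-03-12.sh:40:main) decoding started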
47  .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh (vendored, executable file)
@@ -0,0 +1,47 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./transducer_stateless2/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./transducer_stateless2/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done
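Each of these scripts derives its local clone directory from the URL with basename,
which keeps the last path component; for example:

  repo=$(basename https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19)
  echo $repo  # icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19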
46  .github/scripts/run-pre-trained-conformer-ctc.sh (vendored, executable file)
@@ -0,0 +1,46 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.flac
ls -lh $repo/test_wavs/*.flac

log "CTC decoding"

./conformer_ctc/pretrained.py \
  --method ctc-decoding \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.flac \
  $repo/test_wavs/1221-135766-0001.flac \
  $repo/test_wavs/1221-135766-0002.flac

log "HLG decoding"

./conformer_ctc/pretrained.py \
  --method 1best \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --words-file $repo/data/lang_bpe_500/words.txt \
  --HLG $repo/data/lang_bpe_500/HLG.pt \
  $repo/test_wavs/1089-134686-0001.flac \
  $repo/test_wavs/1221-135766-0001.flac \
  $repo/test_wavs/1221-135766-0002.flac
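The checkpoint and the HLG graph live in Git LFS, so a clone made without
"git lfs install" leaves small pointer files behind instead of the real binaries.
A quick sanity check (a sketch; an LFS pointer file is only ~130 bytes, while the
real checkpoint is far larger):

  ls -lh $repo/exp/pretrained.pt $repo/data/lang_bpe_500/HLG.pt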
47  .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh (vendored, executable file)
@@ -0,0 +1,47 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./transducer_stateless_multi_datasets/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./transducer_stateless_multi_datasets/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done
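soxi prints the audio header for each test wav. LibriSpeech audio is 16 kHz mono, so
the output looks roughly like this sketch (the duration and file name are illustrative):

  # Input File     : 'icefall-asr-.../test_wavs/1089-134686-0001.wav'
  # Channels       : 1
  # Sample Rate    : 16000
  # Duration       : 00:00:06.62 = 105840 samples ~ 496.125 CDDA sectors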
47  .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh (vendored, executable file)
@@ -0,0 +1,47 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./transducer_stateless_multi_datasets/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./transducer_stateless_multi_datasets/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done
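Both beam-search variants above use the same beam size. One way to compare their
hypotheses on a single utterance is to capture and diff the output, as in this sketch
(log file paths are illustrative; timestamps in the logs will differ between runs):

  for method in modified_beam_search beam_search; do
    ./transducer_stateless_multi_datasets/pretrained.py \
      --method $method \
      --beam-size 4 \
      --checkpoint $repo/exp/pretrained.pt \
      --bpe-model $repo/data/lang_bpe_500/bpe.model \
      $repo/test_wavs/1089-134686-0001.wav > /tmp/$method.log 2>&1
  done
  diff /tmp/modified_beam_search.log /tmp/beam_search.log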
47  .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh (vendored, executable file)
@@ -0,0 +1,47 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/aishell/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./transducer_stateless_modified-2/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --lang-dir $repo/data/lang_char \
    $repo/test_wavs/BAC009S0764W0121.wav \
    $repo/test_wavs/BAC009S0764W0122.wav \
    $repo/test_wavs/BAC009S0764W0123.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./transducer_stateless_modified-2/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --lang-dir $repo/data/lang_char \
    $repo/test_wavs/BAC009S0764W0121.wav \
    $repo/test_wavs/BAC009S0764W0122.wav \
    $repo/test_wavs/BAC009S0764W0123.wav
done
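Unlike the LibriSpeech recipes above, the aishell recipe models Chinese characters
directly, so pretrained.py takes --lang-dir (a character-level lang directory) instead
of --bpe-model. A quick peek at the token inventory (the tokens.txt file name is an
assumption based on the usual icefall lang-dir layout):

  head -5 $repo/data/lang_char/tokens.txt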
47  .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh (vendored, executable file)
@@ -0,0 +1,47 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/aishell/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./transducer_stateless_modified/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --lang-dir $repo/data/lang_char \
    $repo/test_wavs/BAC009S0764W0121.wav \
    $repo/test_wavs/BAC009S0764W0122.wav \
    $repo/test_wavs/BAC009S0764W0123.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./transducer_stateless_modified/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --lang-dir $repo/data/lang_char \
    $repo/test_wavs/BAC009S0764W0121.wav \
    $repo/test_wavs/BAC009S0764W0122.wav \
    $repo/test_wavs/BAC009S0764W0123.wav
done
60  .github/scripts/run-pre-trained-transducer-stateless.sh (vendored, executable file)
@@ -0,0 +1,60 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

for sym in 1 2 3; do
  log "Greedy search with --max-sym-per-frame $sym"

  ./transducer_stateless/pretrained.py \
    --method greedy_search \
    --max-sym-per-frame $sym \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./transducer_stateless/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done

for method in modified_beam_search beam_search; do
  log "$method"

  ./transducer_stateless_multi_datasets/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/pretrained.pt \
    --bpe-model $repo/data/lang_bpe_500/bpe.model \
    $repo/test_wavs/1089-134686-0001.wav \
    $repo/test_wavs/1221-135766-0001.wav \
    $repo/test_wavs/1221-135766-0002.wav
done
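These scripts assume they are invoked from the root of an icefall checkout with icefall
and kaldifeat importable. The workflows further down arrange that before calling them,
roughly like this sketch (the paths match the CI steps below):

  export PYTHONPATH=$PWD:$PYTHONPATH
  export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
  export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
  .github/scripts/run-pre-trained-transducer-stateless.sh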
32  .github/scripts/run-pre-trained-transducer.sh (vendored, executable file)
@@ -0,0 +1,32 @@
#!/usr/bin/env bash

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/librispeech/ASR

repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav

log "Beam search decoding"

./transducer/pretrained.py \
  --method beam_search \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
104  .github/workflows/run-librispeech-2022-03-12.yml (vendored)
@@ -40,11 +40,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@@ -77,104 +72,11 @@
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs
          mkdir -p ~/tmp
          cd ~/tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12

      - name: Display test files
        shell: bash
        run: |
          sudo apt-get -qq install tree sox
          tree ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          soxi ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
          ls -lh ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav

      - name: Run greedy search decoding (max-sym-per-frame 1)
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 1 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 2)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 2 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 3)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 3 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run modified beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method modified_beam_search \
            --beam-size 4 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav
          .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
82  .github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml (vendored, normal file)
@@ -0,0 +1,82 @@
# Copyright      2021  Fangjun Kuang (csukuangfj@gmail.com)

# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: run-librispeech-2022-04-19
# stateless transducer + torchaudio rnn-t loss

on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]

jobs:
  run_librispeech_2022_04_19:
    if: github.event.label.name == 'ready' || github.event_name == 'push'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-18.04]
        python-version: [3.7, 3.8, 3.9]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          cache-dependency-path: '**/requirements-ci.txt'

      - name: Install Python dependencies
        run: |
          grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install

      - name: Cache kaldifeat
        id: my-cache
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/kaldifeat
          key: cache-tmp-${{ matrix.python-version }}

      - name: Install kaldifeat
        if: steps.my-cache.outputs.cache-hit != 'true'
        shell: bash
        run: |
          mkdir -p ~/tmp
          cd ~/tmp
          git clone https://github.com/csukuangfj/kaldifeat
          cd kaldifeat
          mkdir build
          cd build
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
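The dependency step above strips comment lines from requirements-ci.txt and installs
each remaining entry with its own pip invocation; roughly equivalent to this sketch:

  while read -r req; do
    pip install "$req"
  done < <(grep -v '^#' ./requirements-ci.txt)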
@@ -39,11 +39,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@@ -76,48 +71,11 @@
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
          cd ..
          tree tmp
          soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
          ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac

      - name: Run CTC decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./conformer_ctc/pretrained.py \
            --num-classes 500 \
            --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \
            --method ctc-decoding \
            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac

      - name: Run HLG decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./conformer_ctc/pretrained.py \
            --num-classes 500 \
            --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
            --words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \
            --HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \
            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
          .github/scripts/run-pre-trained-conformer-ctc.sh
@@ -39,11 +39,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@@ -76,97 +71,11 @@
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21

          cd ..
          tree tmp
          soxi tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav
          ls -lh tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav

      - name: Run greedy search decoding (max-sym-per-frame 1)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 1 \
            --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 2)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 2 \
            --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 3)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 3 \
            --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav

      - name: Run modified beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method modified_beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
          .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
@@ -39,11 +39,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@@ -76,99 +71,11 @@
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01

          cd ..
          tree tmp
          soxi tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav
          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav

      - name: Run greedy search decoding (max-sym-per-frame 1)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 1 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 2)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 2 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 3)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 3 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav

      - name: Run modified beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless_multi_datasets/pretrained.py \
            --method modified_beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav
          .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
@@ -39,11 +39,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@@ -76,98 +71,11 @@
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/aishell/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01

          cd ..
          tree tmp
          soxi tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav
          ls -lh tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav

      - name: Run greedy search decoding (max-sym-per-frame 1)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified-2/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 1 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run greedy search decoding (max-sym-per-frame 2)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified-2/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 2 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run greedy search decoding (max-sym-per-frame 3)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified-2/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 3 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified-2/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run modified beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified-2/pretrained.py \
            --method modified_beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav
          .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
@@ -39,11 +39,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@@ -76,98 +71,11 @@
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/aishell/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01

          cd ..
          tree tmp
          soxi tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav
          ls -lh tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav

      - name: Run greedy search decoding (max-sym-per-frame 1)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 1 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run greedy search decoding (max-sym-per-frame 2)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 2 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run greedy search decoding (max-sym-per-frame 3)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 3 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav

      - name: Run modified beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/aishell/ASR
          ./transducer_stateless_modified/pretrained.py \
            --method modified_beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
            --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
            ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
          .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
@@ -39,11 +39,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@@ -76,96 +71,11 @@
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
          cd ..
          tree tmp
          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav

      - name: Run greedy search decoding (max-sym-per-frame 1)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 1 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 2)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 2 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 3)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 3 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav

      - name: Run modified beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer_stateless/pretrained.py \
            --method modified_beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
          .github/scripts/run-pre-trained-transducer-stateless.sh
46
.github/workflows/run-pretrained-transducer.yml
vendored
@ -39,11 +39,6 @@ jobs:
        with:
          fetch-depth: 0

      - name: Install graphviz
        shell: bash
        run: |
          sudo apt-get -qq install graphviz

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
@ -76,48 +71,11 @@ jobs:
          cmake -DCMAKE_BUILD_TYPE=Release ..
          make -j2 _kaldifeat

      - name: Download pre-trained model
      - name: Inference with pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23

          cd ..
          tree tmp
          soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
          ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav

      - name: Run greedy search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer/pretrained.py \
            --method greedy_search \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
.github/scripts/run-pre-trained-transducer.sh
@ -7,7 +7,8 @@ The following table lists the differences among them.
|                                        | Encoder   | Decoder            | Comment                                               |
|----------------------------------------|-----------|--------------------|-------------------------------------------------------|
| `transducer`                           | Conformer | LSTM               |                                                       |
| `transducer_stateless`                 | Conformer | Embedding + Conv1d |                                                       |
| `transducer_stateless`                 | Conformer | Embedding + Conv1d | Using optimized_transducer for computing RNN-T loss   |
| `transducer_stateless2`                | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss             |
| `transducer_lstm`                      | LSTM      | LSTM               |                                                       |
| `transducer_stateless_multi_datasets`  | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data     |
| `pruned_transducer_stateless`          | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss                            |
@ -134,7 +134,7 @@ This is with a reworked version of the conformer encoder, with many changes.

#### Training on full librispeech

using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`.
Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`.
See <https://github.com/k2-fsa/icefall/pull/288>

The WERs are:
@ -477,6 +477,78 @@ You can find a pretrained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01>


##### 2022-04-19

[transducer_stateless2](./transducer_stateless2)

This version uses torchaudio's RNN-T loss.

Using commit `fce7f3cd9a486405ee008bcbe4999264f27774a3`.
See <https://github.com/k2-fsa/icefall/pull/316>

|                                     | test-clean | test-other | comment                                                                        |
|-------------------------------------|------------|------------|--------------------------------------------------------------------------------|
| greedy search (max sym per frame 1) | 2.65       | 6.30       | --epoch 59 --avg 10 --max-duration 600                                         |
| greedy search (max sym per frame 2) | 2.62       | 6.23       | --epoch 59 --avg 10 --max-duration 100                                         |
| greedy search (max sym per frame 3) | 2.62       | 6.23       | --epoch 59 --avg 10 --max-duration 100                                         |
| modified beam search                | 2.63       | 6.15       | --epoch 59 --avg 10 --max-duration 100 --decoding-method modified_beam_search  |
| beam search                         | 2.59       | 6.15       | --epoch 59 --avg 10 --max-duration 100 --decoding-method beam_search           |

**Note**: This model is trained with standard RNN-T loss. Neither modified transducer nor pruned RNN-T is used.
You can see that there is a degradation in WER when we limit the maximum number of symbols per frame to 1.

The number of active paths in `modified_beam_search` and `beam_search` is 4.

The training and decoding commands are:

```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

./transducer_stateless2/train.py \
  --world-size 8 \
  --num-epochs 60 \
  --start-epoch 0 \
  --exp-dir transducer_stateless2/exp-2 \
  --full-libri 1 \
  --max-duration 300 \
  --lr-factor 5

epoch=59
avg=10
# greedy search
./transducer_stateless2/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./transducer_stateless2/exp-2 \
  --max-duration 600 \
  --decoding-method greedy_search \
  --max-sym-per-frame 1

# modified beam search
./transducer_stateless2/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./transducer_stateless2/exp-2 \
  --max-duration 100 \
  --decoding-method modified_beam_search

# beam search
./transducer_stateless2/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./transducer_stateless2/exp-2 \
  --max-duration 100 \
  --decoding-method beam_search
```

The tensorboard log is at <https://tensorboard.dev/experiment/oAlle3dxQD2EY8ePwjIGuw/>.

You can find a pre-trained model, decoding logs, and decoding results at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19>


##### 2022-02-07

Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`.
141
egs/librispeech/ASR/local/compile_lg.py
Executable file
@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates LG from

    - L, the lexicon, built from lang_dir/L_disambig.pt

      Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_LG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

    Return:
      An FSA representing LG.
    """
    lexicon = Lexicon(lang_dir)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    return LG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "LG.pt").is_file():
        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    logging.info(f"Saving LG.pt to {lang_dir}")
    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
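A quick way to sanity-check the output of this script is to load the saved graph back with k2. This is a hypothetical snippet, not part of the commit; it assumes `data/lang_bpe_500/LG.pt` has already been generated as above:

```python
# Hypothetical check: load the compiled LG back and print basic statistics.
import k2
import torch

LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt"))
print(LG.shape)     # (num_states, None) for a single FSA
print(LG.num_arcs)  # total number of arcs in the graph
```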
@ -242,3 +242,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
    ./local/compile_hlg.py --lang-dir $lang_dir
  done
fi

# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
  log "Stage 10: Compile LG"
  ./local/compile_lg.py --lang-dir data/lang_phone

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}
    ./local/compile_lg.py --lang-dir $lang_dir
  done
fi
@ -22,7 +22,7 @@ import k2
import torch
from model import Transducer

from icefall.decode import one_best_decoding
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts


@ -34,6 +34,7 @@ def fast_beam_search(
    beam: float,
    max_states: int,
    max_contexts: int,
    use_max: bool = False,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.

@ -53,6 +54,9 @@ def fast_beam_search(
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return the decoded result.
    """
@ -104,9 +108,67 @@ def fast_beam_search(
    decoding_streams.terminate_and_flush_to_streams()
    lattice = decoding_streams.format_output(encoder_out_lens.tolist())

    best_path = one_best_decoding(lattice)
    hyps = get_texts(best_path)
    return hyps
    if use_max:
        best_path = one_best_decoding(lattice)
        hyps = get_texts(best_path)
        return hyps
    else:
        num_paths = 200
        use_double_scores = True
        nbest_scale = 0.8

        nbest = Nbest.from_lattice(
            lattice=lattice,
            num_paths=num_paths,
            use_double_scores=use_double_scores,
            nbest_scale=nbest_scale,
        )
        # The following code is modified from nbest.intersect()
        word_fsa = k2.invert(nbest.fsa)
        if hasattr(lattice, "aux_labels"):
            # delete token IDs as they are not needed
            del word_fsa.aux_labels
        word_fsa.scores.zero_()

        word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
        path_to_utt_map = nbest.shape.row_ids(1)

        if hasattr(lattice, "aux_labels"):
            # lattice has token IDs as labels and word IDs as aux_labels.
            # inv_lattice has word IDs as labels and token IDs as aux_labels
            inv_lattice = k2.invert(lattice)
            inv_lattice = k2.arc_sort(inv_lattice)
        else:
            inv_lattice = k2.arc_sort(lattice)

        if inv_lattice.shape[0] == 1:
            path_lattice = k2.intersect_device(
                inv_lattice,
                word_fsa_with_epsilon_loops,
                b_to_a_map=torch.zeros_like(path_to_utt_map),
                sorted_match_a=True,
            )
        else:
            path_lattice = k2.intersect_device(
                inv_lattice,
                word_fsa_with_epsilon_loops,
                b_to_a_map=path_to_utt_map,
                sorted_match_a=True,
            )

        # path_lattice has word IDs as labels and token IDs as aux_labels
        path_lattice = k2.top_sort(k2.connect(path_lattice))

        tot_scores = path_lattice.get_tot_scores(
            use_double_scores=use_double_scores, log_semiring=True
        )

        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
        best_hyp_indexes = ragged_tot_scores.argmax()

        best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
        hyps = get_texts(best_path)
        return hyps


def greedy_search(
|
||||
def data(self) -> Dict[str, Hypothesis]:
|
||||
return self._data
|
||||
|
||||
def add(self, hyp: Hypothesis) -> None:
|
||||
def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
If `hyp` already exists in `self`, its probability is updated using
|
||||
@ -289,13 +351,20 @@ class HypothesisList(object):
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be added.
|
||||
use_max:
|
||||
True to select the hypothesis with the larger log_prob in case there
|
||||
already exists a hypothesis whose `ys` equals to `hyp.ys`.
|
||||
False to use log_add.
|
||||
"""
|
||||
key = hyp.key
|
||||
if key in self:
|
||||
old_hyp = self._data[key] # shallow copy
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
if use_max:
|
||||
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
|
||||
else:
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
else:
|
||||
self._data[key] = hyp
|
||||
|
||||
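For reference, a minimal standalone sketch of the two merge rules used by `HypothesisList.add` when a duplicate hypothesis arrives (hypothetical values, not part of this commit):

```python
# Two ways to merge the log-probs of duplicate hypotheses.
import torch

old_log_prob = torch.tensor(-1.0)  # existing hypothesis
new_log_prob = torch.tensor(-2.0)  # duplicate with the same ys

# use_max=False: combine the probability mass of both paths (log-add)
print(torch.logaddexp(old_log_prob, new_log_prob))  # tensor(-0.6867)

# use_max=True: keep only the more probable path (Viterbi-style)
print(torch.maximum(old_log_prob, new_log_prob))    # tensor(-1.)
```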
@ -403,6 +472,7 @@ def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    use_max: bool = False,
) -> List[List[int]]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

@ -413,6 +483,9 @@ def modified_beam_search(
        Output from the encoder. Its shape is (N, T, C).
      beam:
        Number of active paths during the beam search.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return a list-of-list of token IDs. ans[i] is the decoding result
      for the i-th utterance.
@ -432,7 +505,8 @@ def modified_beam_search(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
            )
            ),
            use_max=use_max,
        )

    for t in range(T):
@ -517,6 +591,7 @@ def _deprecated_modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    use_max: bool = False,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.

@ -532,6 +607,9 @@ def _deprecated_modified_beam_search(
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return the decoded result.
    """
@ -553,12 +631,13 @@ def _deprecated_modified_beam_search(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
        )
        ),
        use_max=use_max,
    )

    for t in range(T):
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2)
        # current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
        # fmt: on
        A = list(B)
@ -611,7 +690,7 @@ def _deprecated_modified_beam_search(
            new_ys.append(new_token)
            new_log_prob = topk_log_probs[i]
            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
            B.add(new_hyp)
            B.add(new_hyp, use_max=use_max)

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
@ -623,6 +702,7 @@ def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    use_max: bool = False,
) -> List[int]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -636,6 +716,9 @@ def beam_search(
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return the decoded result.
    """
@ -661,7 +744,9 @@ def beam_search(
    t = 0

    B = HypothesisList()
    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
    B.add(
        Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max
    )

    max_sym_per_utt = 20000

@ -720,7 +805,10 @@ def beam_search(
            new_y_star_log_prob = y_star.log_prob + skip_log_prob

            # ys[:] returns a copy of ys
            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
            B.add(
                Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
                use_max=use_max,
            )

            # Second, process other non-blank labels
            values, indices = log_prob.topk(beam + 1)
@ -729,7 +817,10 @@ def beam_search(
                    continue
                new_ys = y_star.ys + [i]
                new_log_prob = y_star.log_prob + v
                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
                A.add(
                    Hypothesis(ys=new_ys, log_prob=new_log_prob),
                    use_max=use_max,
                )

        # Check whether B contains more than "beam" elements more probable
        # than the most probable in A
@ -53,6 +53,19 @@ Usage:
    --beam 4 \
    --max-contexts 4 \
    --max-states 8

(5) fast beam search using LG
./pruned_transducer_stateless/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless/exp \
    --use-LG True \
    --use-max False \
    --max-duration 1500 \
    --decoding-method fast_beam_search \
    --beam 8 \
    --max-contexts 8 \
    --max-states 64
"""


@ -81,10 +94,12 @@ from icefall.checkpoint import (
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)

@ -136,6 +151,13 @@ def get_parser():
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_500",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
@ -167,6 +189,36 @@ def get_parser():
        Used only when --decoding-method is fast_beam_search""",
    )

    parser.add_argument(
        "--use-LG",
        type=str2bool,
        default=False,
        help="""Whether to use an LG graph for FSA-based beam search.
        Used only when --decoding_method is fast_beam_search. If set to True,
        it assumes there is an LG.pt file in lang_dir.""",
    )

    parser.add_argument(
        "--use-max",
        type=str2bool,
        default=False,
        help="""If True, use max-op to select the hypothesis that has the
        max log_prob in case of duplicate hypotheses.
        If False, use log_add.
        Used only for beam_search, modified_beam_search, and fast_beam_search
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding_method is fast_beam_search.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
@ -206,6 +258,7 @@ def decode_one_batch(
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
@ -229,6 +282,8 @@ def decode_one_batch(
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search.
@ -260,9 +315,14 @@ def decode_one_batch(
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            use_max=params.use_max,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
        if params.use_LG:
            for hyp in hyp_tokens:
                hyps.append([word_table[i] for i in hyp])
        else:
            for hyp in sp.decode(hyp_tokens):
                hyps.append(hyp.split())
    elif (
        params.decoding_method == "greedy_search"
        and params.max_sym_per_frame == 1
@ -278,6 +338,7 @@ def decode_one_batch(
            model=model,
            encoder_out=encoder_out,
            beam=params.beam_size,
            use_max=params.use_max,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@ -299,6 +360,7 @@ def decode_one_batch(
                model=model,
                encoder_out=encoder_out_i,
                beam=params.beam_size,
                use_max=params.use_max,
            )
        else:
            raise ValueError(
@ -325,6 +387,7 @@ def decode_dataset(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.
@ -338,6 +401,8 @@ def decode_dataset(
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search.
@ -368,8 +433,9 @@ def decode_dataset(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            batch=batch,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        for name, hyps in hyps_dict.items():
@ -460,13 +526,16 @@ def main():
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-use-LG-{params.use_LG}"
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        params.suffix += f"-use-max-{params.use_max}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
        params.suffix += f"-use-max-{params.use_max}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -527,9 +596,21 @@ def main():
    model.device = device

    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        if params.use_LG:
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            decoding_graph = k2.Fsa.from_dict(
                torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(
                params.vocab_size - 1, device=device
            )
    else:
        decoding_graph = None
        word_table = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
@ -551,6 +632,7 @@ def main():
        params=params,
        model=model,
        sp=sp,
        word_table=word_table,
        decoding_graph=decoding_graph,
    )

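For intuition, here is a hypothetical, self-contained sketch of the branch added above: with an LG graph the lattice labels are word IDs, which are looked up in the word table, whereas without LG they are BPE token IDs that would go through `sp.decode`. The IDs and table below are made up for illustration:

```python
# Made-up word table and hypothesis IDs, standing in for
# lexicon.word_table and the output of fast_beam_search.
use_LG = True
word_table = {7: "HELLO", 9: "WORLD"}
hyp_tokens = [[7, 9]]

hyps = []
if use_LG:
    # With LG, labels are word IDs: map each one through the word table.
    for hyp in hyp_tokens:
        hyps.append([word_table[i] for i in hyp])
else:
    # Without LG, labels are BPE token IDs; sp.decode(...) would be used.
    pass

print(hyps)  # [['HELLO', 'WORLD']]
```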
@ -811,13 +811,23 @@ def run(rank, world_size, args):

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    num_left = len(train_cuts)
    num_removed = num_in_total - num_left
    removed_percent = num_removed / num_in_total * 100
    try:
        num_left = len(train_cuts)
        num_removed = num_in_total - num_left
        removed_percent = num_removed / num_in_total * 100

    logging.info(f"Before removing short and long utterances: {num_in_total}")
    logging.info(f"After removing short and long utterances: {num_left}")
    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
        logging.info(
            f"Before removing short and long utterances: {num_in_total}"
        )
        logging.info(f"After removing short and long utterances: {num_left}")
        logging.info(
            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
        )
    except TypeError as e:
        # You can ignore this error as previous versions of Lhotse work fine
        # for the above code. In recent versions of Lhotse, it uses
        # lazy filter, producing cutsets that don't have the __len__ method
        logging.info(str(e))

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
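A minimal standalone sketch of the failure mode this try/except guards against (hypothetical; `cuts` stands in for a Lhotse CutSet):

```python
# A lazily-filtered Lhotse CutSet may not implement __len__, in which
# case len() raises TypeError and the statistics are simply skipped.
def count_or_none(cuts):
    try:
        return len(cuts)  # eager CutSet: __len__ is available
    except TypeError:
        return None       # lazy CutSet: no __len__
```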
@ -98,27 +98,28 @@ def get_parser():
        "--epoch",
        type=int,
        default=28,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--avg-last-n",
        type=int,
        default=0,
        help="""If positive, --epoch and --avg are ignored and it
        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
        where xxx is the number of processed batches while
        saving that checkpoint.
        """,
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
@ -453,13 +454,19 @@ def main():
    )
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam_size}"
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -476,8 +483,9 @@ def main():
    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)
@ -485,8 +493,20 @@ def main():
    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.avg_last_n > 0:
        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 0 \
  --use_fp16 1 \
  --use-fp16 1 \
  --exp-dir pruned_transducer_stateless2/exp \
  --full-libri 1 \
  --max-duration 550
@ -89,9 +89,9 @@ class Decoder(nn.Module):
          - (h, c), containing the state information for LSTM layers.
            Both are of shape (num_layers, N, C)
        """
        embeding_out = self.embedding(y)
        embeding_out = self.embedding_dropout(embeding_out)
        rnn_out, (h, c) = self.rnn(embeding_out, states)
        embedding_out = self.embedding(y)
        embedding_out = self.embedding_dropout(embedding_out)
        rnn_out, (h, c) = self.rnn(embedding_out, states)
        out = self.output_linear(rnn_out)

        return out, (h, c)
@ -93,9 +93,9 @@ class Decoder(nn.Module):
          - (h, c), containing the state information for LSTM layers.
            Both are of shape (num_layers, N, C)
        """
        embeding_out = self.embedding(y)
        embeding_out = self.embedding_dropout(embeding_out)
        rnn_out, (h, c) = self.rnn(embeding_out, states)
        embedding_out = self.embedding(y)
        embedding_out = self.embedding_dropout(embedding_out)
        rnn_out, (h, c) = self.rnn(embedding_out, states)
        out = self.output_linear(rnn_out)

        return out, (h, c)
@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional

@ -505,8 +506,10 @@ def modified_beam_search(
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
            topk_token_indexes = (topk_indexes % vocab_size).tolist()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
@ -613,8 +616,10 @@ def _deprecated_modified_beam_search(
        topk_hyp_indexes = topk_indexes // logits.size(-1)
        topk_token_indexes = topk_indexes % logits.size(-1)

        topk_hyp_indexes = topk_hyp_indexes.tolist()
        topk_token_indexes = topk_token_indexes.tolist()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            topk_hyp_indexes = topk_hyp_indexes.tolist()
            topk_token_indexes = topk_token_indexes.tolist()

        for i in range(len(topk_hyp_indexes)):
            hyp = A[topk_hyp_indexes[i]]
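A hypothetical minimal reproduction of the warning being silenced here: on some PyTorch versions, `//` on tensors emits a UserWarning suggesting `torch.div(..., rounding_mode='floor')` instead:

```python
import warnings

import torch

topk_indexes = torch.tensor([0, 501, 1003])
vocab_size = 500
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    # tensor // int may warn about floor-division semantics
    topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
print(topk_hyp_indexes)  # [0, 1, 2]
```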
@ -653,13 +653,23 @@ def run(rank, world_size, args):

    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    num_left = len(train_cuts)
    num_removed = num_in_total - num_left
    removed_percent = num_removed / num_in_total * 100
    try:
        num_left = len(train_cuts)
        num_removed = num_in_total - num_left
        removed_percent = num_removed / num_in_total * 100

    logging.info(f"Before removing short and long utterances: {num_in_total}")
    logging.info(f"After removing short and long utterances: {num_left}")
    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
        logging.info(
            f"Before removing short and long utterances: {num_in_total}"
        )
        logging.info(f"After removing short and long utterances: {num_left}")
        logging.info(
            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
        )
    except TypeError as e:
        # You can ignore this error as previous versions of Lhotse work fine
        # for the above code. In recent versions of Lhotse, it uses
        # lazy filter, producing cutsets that don't have the __len__ method
        logging.info(str(e))

    train_dl = librispeech.train_dataloaders(train_cuts)

1
egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py
1
egs/librispeech/ASR/transducer_stateless2/beam_search.py
Symbolic link
@ -0,0 +1 @@
../transducer_stateless/beam_search.py
1
egs/librispeech/ASR/transducer_stateless2/conformer.py
Symbolic link
@ -0,0 +1 @@
../transducer_stateless/conformer.py
443
egs/librispeech/ASR/transducer_stateless2/decode.py
Executable file
@ -0,0 +1,443 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless2/decode.py \
    --epoch 14 \
    --avg 7 \
    --exp-dir ./transducer_stateless2/exp \
    --max-duration 100 \
    --decoding-method greedy_search

(2) beam search
./transducer_stateless2/decode.py \
    --epoch 14 \
    --avg 7 \
    --exp-dir ./transducer_stateless2/exp \
    --max-duration 100 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./transducer_stateless2/decode.py \
    --epoch 14 \
    --avg 7 \
    --exp-dir ./transducer_stateless2/exp \
    --max-duration 100 \
    --decoding-method modified_beam_search \
    --beam-size 4
"""


import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
    beam_search,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
)
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    write_error_stats,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=29,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=13,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="transducer_stateless2/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""Used only when --decoding-method is
        beam_search or modified_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    return parser


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if greedy_search is used, it would be "greedy_search"
               If beam search with a beam size of 7 is used, it would be
               "beam_7"
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = model.device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(
        x=feature, x_lens=feature_lens
    )
    hyp_list: List[List[int]] = []

    if (
        params.decoding_method == "greedy_search"
        and params.max_sym_per_frame == 1
    ):
        hyp_list = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
        )
    elif params.decoding_method == "modified_beam_search":
        hyp_list = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            beam=params.beam_size,
        )
    else:
        batch_size = encoder_out.size(0)
        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyp_list.append(hyp)

    hyps = [sp.decode(hyp).split() for hyp in hyp_list]

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    else:
        return {f"beam_{params.beam_size}": hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if a beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 100
    else:
        log_interval = 2

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for hyp_words, ref_text in zip(hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    if "beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to(device)
    model.eval()
    model.device = device

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dl = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
egs/librispeech/ASR/transducer_stateless2/decoder.py
Symbolic link
1
egs/librispeech/ASR/transducer_stateless2/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless/decoder.py
|
1
egs/librispeech/ASR/transducer_stateless2/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/transducer_stateless2/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless/encoder_interface.py
|
181
egs/librispeech/ASR/transducer_stateless2/export.py
Executable file
181
egs/librispeech/ASR/transducer_stateless2/export.py
Executable file
@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
Usage:
|
||||
./transducer_stateless2/export.py \
|
||||
--exp-dir ./transducer_stateless2/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
|
||||
To use the generated file with `transducer_stateless2/decode.py`, you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./transducer_stateless2/decode.py \
|
||||
--exp-dir ./transducer_stateless2/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 1 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=20,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="transducer_stateless2/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
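For reference, a checkpoint exported this way can be loaded back with plain
torch.load. A minimal sketch, assuming the default --exp-dir and a 500-token
BPE model whose <blk> id is 0 (both values are assumptions here; pretrained.py
below does the same thing with the proper setup):

import torch
from train import get_params, get_transducer_model

params = get_params()
# These normally come from the BPE model and the command line;
# the values below are illustrative assumptions.
params.blank_id = 0
params.vocab_size = 500
params.context_size = 2

model = get_transducer_model(params)
checkpoint = torch.load("transducer_stateless2/exp/pretrained.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()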
67
egs/librispeech/ASR/transducer_stateless2/joiner.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(self, input_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.output_linear = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
*unused,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, self.input_dim).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, U, self.input_dim).
|
||||
unused:
|
||||
This is a placeholder so that we can reuse
|
||||
transducer_stateless/beam_search.py in this folder as that
|
||||
script assumes the joiner network accepts 4 inputs.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, U, self.output_dim).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
||||
assert encoder_out.size(0) == decoder_out.size(0)
|
||||
assert encoder_out.size(2) == self.input_dim
|
||||
assert decoder_out.size(2) == self.input_dim
|
||||
|
||||
encoder_out = encoder_out.unsqueeze(2) # (N, T, 1, C)
|
||||
decoder_out = decoder_out.unsqueeze(1) # (N, 1, U, C)
|
||||
x = encoder_out + decoder_out # (N, T, U, C)
|
||||
|
||||
activations = torch.tanh(x)
|
||||
|
||||
logits = self.output_linear(activations)
|
||||
|
||||
if not self.training:
|
||||
# We reuse the beam_search.py from transducer_stateless,
|
||||
# which expects that the joiner network outputs
|
||||
# a 2-D tensor.
|
||||
logits = logits.squeeze(2).squeeze(1)
|
||||
|
||||
return logits
|
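A quick shape check of the broadcasting in forward() above: a minimal sketch,
assuming this file is importable as joiner:

import torch
from joiner import Joiner

joiner = Joiner(input_dim=512, output_dim=500)
encoder_out = torch.randn(2, 100, 512)  # (N, T, C)
decoder_out = torch.randn(2, 10, 512)   # (N, U, C)

joiner.train()  # in training mode the 4-D output is kept
logits = joiner(encoder_out, decoder_out)
assert logits.shape == (2, 100, 10, 500)  # (N, T, U, output_dim)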
130
egs/librispeech/ASR/transducer_stateless2/model.py
Normal file
@ -0,0 +1,130 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Note we use `rnnt_loss` from torchaudio, which exists only in
|
||||
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
|
||||
"""
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
import torchaudio.functional
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of shape (N, T, C) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, C) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, C). It should contain
|
||||
one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||
output shape is (N, T, U, C). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
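# For example, with blank_id = 0 and y = [[2, 5], [9]]:
#   sos_y        -> [[0, 2, 5], [0, 9]]
#   sos_y_padded -> [[0, 2, 5], [0, 9, 0]]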
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
logits = self.joiner(
|
||||
encoder_out=encoder_out,
|
||||
decoder_out=decoder_out,
|
||||
)
|
||||
|
||||
# rnnt_loss requires 0 padded targets
|
||||
# Note: y does not start with SOS
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
||||
"Please install a version >= 0.10.0"
|
||||
)
|
||||
|
||||
loss = torchaudio.functional.rnnt_loss(
|
||||
logits=logits,
|
||||
targets=y_padded,
|
||||
logit_lengths=x_lens,
|
||||
target_lengths=y_lens,
|
||||
blank=blank_id,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return loss
|
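The input layout torchaudio's rnnt_loss expects can be exercised in isolation.
A minimal sketch with random tensors (shapes only, the values are meaningless;
requires torchaudio >= 0.10):

import torch
import torchaudio.functional

N, T, U, V = 2, 50, 10, 500  # batch, frames, target length, vocab size
# The third axis has U + 1 positions because of the prepended blank/sos symbol.
logits = torch.randn(N, T, U + 1, V, requires_grad=True)
targets = torch.randint(1, V, (N, U), dtype=torch.int32)
logit_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.full((N,), U, dtype=torch.int32)

loss = torchaudio.functional.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=0,
    reduction="sum",
)
loss.backward()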
293
egs/librispeech/ASR/transducer_stateless2/pretrained.py
Executable file
@ -0,0 +1,293 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
--max-sym-per-frame 1 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
(2) beam search
|
||||
./transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
(3) modified beam search
|
||||
./transducer_stateless2/pretrained.py \
|
||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
You can also use `./transducer_stateless2/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./transducer_stateless2/exp/pretrained.pt is generated by
|
||||
./transducer_stateless2/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
It is used to map token ids back to words.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Used only when --method is beam_search and modified_beam_search ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
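# Note: files at other sample rates are rejected by the assert above. One way
# to resample beforehand, as a sketch (the file name is made up):
#
#   import torchaudio.functional as AF
#   wave, sr = torchaudio.load("foo_8k.wav")
#   if sr != 16000:
#       wave = AF.resample(wave, orig_freq=sr, new_freq=16000)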
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyp_list = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_list = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_list = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
hyp_list.append(hyp)
|
||||
|
||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
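beam_search.py, from which greedy_search and friends are imported above, is
shared with transducer_stateless and is not part of this diff. The
frame-synchronous loop behind greedy_search looks roughly like the sketch
below; the `joint` callable stands in for the real decoder/joiner wiring and
is not the actual beam_search.py API:

from typing import Callable, List
import torch

def greedy_sketch(
    joint: Callable[[int, List[int]], torch.Tensor],  # (frame index, hyp) -> vocab logits
    num_frames: int,
    blank_id: int = 0,
    max_sym_per_frame: int = 1,
) -> List[int]:
    hyp: List[int] = []
    for t in range(num_frames):
        emitted = 0
        # Emit at most max_sym_per_frame non-blank symbols per frame.
        while emitted < max_sym_per_frame:
            logits = joint(t, hyp)
            token = int(logits.argmax(dim=-1))
            if token == blank_id:
                break  # advance to the next frame
            hyp.append(token)
            emitted += 1
    return hyp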
1
egs/librispeech/ASR/transducer_stateless2/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless/subsampling.py
|
779
egs/librispeech/ASR/transducer_stateless2/train.py
Executable file
@ -0,0 +1,779 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./transducer_stateless2/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless2/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 250 \
|
||||
--lr-factor 2.5
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="transducer_stateless2/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training-related
|
||||
files, e.g., checkpoints, logs, etc., are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-factor",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="The lr_factor for Noam optimizer",
|
||||
)
|
||||
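# An informal note on the schedule (the exact code lives in transformer.py,
# symlinked from transducer_stateless): the Noam learning rate at step s is
#   lr(s) = lr_factor * attention_dim ** -0.5 * min(s ** -0.5, s * warm_step ** -1.5)
# i.e., a linear warm-up for warm_step steps followed by inverse-sqrt decay.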
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--print-diagnostics",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Accumulate stats on activations, print them and exit.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training related parameters that are not passed from the commandline
|
||||
are saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
||||
Explanation of options saved in `params`:
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_valid_loss: Best validation loss so far. It is used to select
|
||||
the model that has the lowest validation loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_train_epoch: It is the epoch that has the best training loss.
|
||||
|
||||
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||
|
||||
- batch_idx_train: Used for writing statistics to tensorboard. It
|
||||
contains number of batches trained so far across
|
||||
epochs.
|
||||
|
||||
- log_interval: Print training loss if batch_idx % log_interval is 0
|
||||
|
||||
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||
|
||||
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- attention_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layers in the transformer decoder.
|
||||
|
||||
- warm_step: The warm_step for Noam optimizer.
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
# parameters for Noam
|
||||
"warm_step": 80000, # For the 100h subset, use 8k
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
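# AttributeDict is a dict that also supports attribute-style access, so the
# two spellings below refer to the same entry:
#
#   params = get_params()
#   assert params.feature_dim == params["feature_dim"] == 80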
|
||||
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
d_model=params.attention_dim,
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
blank_id=params.blank_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
)
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
) -> None:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||
|
||||
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler we are using.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if params.start_epoch <= 0:
|
||||
return
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
) -> Tuple[Tensor, MetricsTracker]:
|
||||
"""
|
||||
Compute transducer loss given the model and its inputs.
|
||||
|
||||
Args:
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
model:
|
||||
The model for training. It is an instance of Transducer in our case.
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
is_training:
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
"""
|
||||
device = model.device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
y = sp.encode(texts, out_type=int)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
)
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
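# For example, with subsampling_factor = 4, an utterance of 1000 input
# frames contributes 250 frames to the normalization term.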
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=False,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(loss.device)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss, averaged over all frames, is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of GPUs in DDP training. If it is 1, DDP is disabled.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"batch {batch_idx}, loss[{loss_info}], "
|
||||
f"tot_loss[{tot_loss}], batch size: {batch_size}"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
if params.full_libri is False:
|
||||
params.valid_interval = 800
|
||||
params.warm_step = 8000
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank])
|
||||
model.device = device
|
||||
|
||||
optimizer = Noam(
|
||||
model.parameters(),
|
||||
model_size=params.attention_dim,
|
||||
factor=params.lr_factor,
|
||||
warm_step=params.warm_step,
|
||||
)
|
||||
|
||||
if checkpoints and "optimizer" in checkpoints:
|
||||
logging.info("Loading optimizer state dict")
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
return 1.0 <= c.duration <= 20.0
|
||||
|
||||
num_in_total = len(train_cuts)
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
try:
|
||||
num_left = len(train_cuts)
|
||||
num_removed = num_in_total - num_left
|
||||
removed_percent = num_removed / num_in_total * 100
|
||||
|
||||
logging.info(
|
||||
f"Before removing short and long utterances: {num_in_total}"
|
||||
)
|
||||
logging.info(f"After removing short and long utterances: {num_left}")
|
||||
logging.info(
|
||||
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
|
||||
)
|
||||
except TypeError as e:
|
||||
# You can ignore this error as previous versions of Lhotse work fine
|
||||
# for the above code. In recent versions of Lhotse, it uses
|
||||
# lazy filter, producing cutsets that don't have the __len__ method
|
||||
logging.info(str(e))
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
fix_random_seed(params.seed + epoch)
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
if params.print_diagnostics:
|
||||
diagnostic.print_diagnostics()
|
||||
break
|
||||
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
logging.info(
|
||||
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
||||
)
|
||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
optimizer.zero_grad()
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
"Your GPU ran out of memory with the current "
|
||||
"max_duration setting. We recommend decreasing "
|
||||
"max_duration and trying again.\n"
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/librispeech/ASR/transducer_stateless2/transformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless/transformer.py
|
@ -84,9 +84,9 @@ class Decoder(nn.Module):
|
||||
- (h, c), which contain the state information for RNN layers.
|
||||
Both are of shape (num_layers, N, C)
|
||||
"""
|
||||
-embeding_out = self.embedding(y)
|
||||
-embeding_out = self.embedding_dropout(embeding_out)
|
||||
-rnn_out, (h, c) = self.rnn(embeding_out, states)
|
||||
+embedding_out = self.embedding(y)
|
||||
+embedding_out = self.embedding_dropout(embedding_out)
|
||||
+rnn_out, (h, c) = self.rnn(embedding_out, states)
|
||||
out = self.output_linear(rnn_out)
|
||||
|
||||
return out, (h, c)
|
||||
|
@ -150,12 +150,25 @@ def average_checkpoints(
|
||||
n = len(filenames)
|
||||
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
uniqued: Dict[int, str] = dict()
|
||||
|
||||
for k, v in avg.items():
|
||||
v_data_ptr = v.data_ptr()
|
||||
if v_data_ptr in uniqued:
|
||||
continue
|
||||
uniqued[v_data_ptr] = k
|
||||
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
-for k in avg:
|
||||
+for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
-for k in avg:
|
||||
+for k in uniqued_names:
|
||||
if avg[k].is_floating_point():
|
||||
avg[k] /= n
|
||||
else:
|
||||
|
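The data_ptr deduplication above can be demonstrated on a toy module with tied
parameters. A minimal, icefall-independent sketch:

import torch.nn as nn

emb = nn.Embedding(10, 4)
proj = nn.Linear(4, 10)
proj.weight = emb.weight  # tie the two parameters to the same storage

state = {"emb.weight": emb.weight, "proj.weight": proj.weight}
uniqued = {}
for name, tensor in state.items():
    ptr = tensor.data_ptr()
    if ptr not in uniqued:
        uniqued[ptr] = name

# Only one of the two tied names survives, so the accumulation loops above
# touch each underlying tensor exactly once instead of twice.
print(list(uniqued.values()))  # ['emb.weight']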