Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 4eb356ce49: Merge branch 'master' into context_biasing
@@ -15,5 +15,5 @@ mkdir -p data
 cd data
 [ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank
 cd ..
-./local/compute_fbank_librispeech.py
+./local/compute_fbank_librispeech.py --dataset 'test-clean test-other'
 ls -lh data/fbank/

@@ -25,7 +25,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp
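The pattern above repeats across the CI test scripts in this commit: the `soxi` calls are replaced by `ls -lh`, so the scripts no longer need sox installed, and the matching `apt-get install ... sox` lines are dropped from the workflow files later in this diff. If the sample-rate/duration details that `soxi` printed are still wanted, the Python standard library can supply them without sox. A minimal sketch, not part of this commit; the directory path is a stand-in:

```python
# Print sample rate, channel count and duration for each test wav,
# roughly the information `soxi` used to report, with no sox dependency.
import glob
import wave

for path in sorted(glob.glob("icefall-model-repo/test_wavs/*.wav")):  # stand-in path
    with wave.open(path) as w:
        duration = w.getnframes() / w.getframerate()
        print(f"{path}: {w.getframerate()} Hz, {w.getnchannels()} ch, {duration:.2f} s")
```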
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -20,7 +20,6 @@ abs_repo=$(realpath $repo)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 for sym in 1 2 3; do

@@ -23,7 +23,6 @@ popd

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -22,7 +22,6 @@ popd

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -18,7 +18,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -10,7 +10,7 @@ log() {

 cd egs/librispeech/ASR

-repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14
+repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29

 log "Downloading pre-trained model from $repo_url"
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 for sym in 1 2 3; do

@@ -18,7 +18,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.flac
 ls -lh $repo/test_wavs/*.flac

 log "CTC decoding"

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 for sym in 1 2 3; do

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 for sym in 1 2 3; do

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 for sym in 1 2 3; do

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 for sym in 1 2 3; do

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 for sym in 1 2 3; do

@@ -19,7 +19,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 log "Beam search decoding"

@@ -20,7 +20,6 @@ repo=$(basename $repo_url)

 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav

 pushd $repo/exp

.github/scripts/test-ncnn-export.sh (vendored, 67 lines changed)

@@ -232,70 +232,3 @@ python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \

 rm -rf $repo
 log "--------------------------------------------------------------------------"
-
-# Go back to the root directory of icefall repo
-popd
-
-pushd egs/csj/ASR
-
-log "=========================================================================="
-repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-git lfs pull --include "exp_fluent/pretrained.pt"
-git lfs pull --include "exp_disfluent/pretrained.pt"
-
-cd exp_fluent
-ln -s pretrained.pt epoch-99.pt
-
-cd ../exp_disfluent
-ln -s pretrained.pt epoch-99.pt
-
-cd ../test_wavs
-git lfs pull --include "*.wav"
-popd
-
-log "Export via torch.jit.trace()"
-
-for exp in exp_fluent exp_disfluent; do
-  ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
-    --exp-dir $repo/$exp/ \
-    --lang $repo/data/lang_char \
-    --epoch 99 \
-    --avg 1 \
-    --use-averaged-model 0 \
-    \
-    --decode-chunk-len 32 \
-    --num-left-chunks 4 \
-    --num-encoder-layers "2,4,3,2,4" \
-    --feedforward-dims "1024,1024,2048,2048,1024" \
-    --nhead "8,8,8,8,8" \
-    --encoder-dims "384,384,384,384,384" \
-    --attention-dims "192,192,192,192,192" \
-    --encoder-unmasked-dims "256,256,256,256,256" \
-    --zipformer-downsampling-factors "1,2,4,8,2" \
-    --cnn-module-kernels "31,31,31,31,31" \
-    --decoder-dim 512 \
-    --joiner-dim 512
-
-  pnnx $repo/$exp/encoder_jit_trace-pnnx.pt
-  pnnx $repo/$exp/decoder_jit_trace-pnnx.pt
-  pnnx $repo/$exp/joiner_jit_trace-pnnx.pt
-
-  for wav in aps-smp.wav interview_aps-smp.wav reproduction-smp.wav sps-smp.wav; do
-    python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
-      --tokens $repo/data/lang_char/tokens.txt \
-      --encoder-param-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.param \
-      --encoder-bin-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.bin \
-      --decoder-param-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.param \
-      --decoder-bin-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.bin \
-      --joiner-param-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.param \
-      --joiner-bin-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.bin \
-      $repo/test_wavs/$wav
-  done
-done
-
-rm -rf $repo
-log "--------------------------------------------------------------------------"

.github/workflows/run-aishell-2022-06-20.yml (vendored, 4 lines changed)

@@ -65,7 +65,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache
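Every workflow touched below applies the same two-line change: protobuf is still built from source (`--no-binary`) but is now pinned to the 3.20 series. The diff does not state the motivation, though staying below the breaking protobuf 4.x releases is the usual reason for such a pin. A quick post-install check, offered as an illustration rather than part of the CI:

```python
# Minimal sanity check that the pin took effect; the assertion string
# matches the constraint added in this commit, nothing more.
import google.protobuf

print(google.protobuf.__version__)  # expect something like "3.20.3"
assert google.protobuf.__version__.startswith("3.20")
```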
@@ -87,7 +87,7 @@ jobs:
         GITHUB_EVENT_NAME: ${{ github.event_name }}
         GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
       run: |
-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -60,7 +60,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -119,7 +119,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -35,7 +35,7 @@ on:

 jobs:
   run_librispeech_2022_12_15_zipformer_ctc_bs:
-    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:

@@ -60,7 +60,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -119,7 +119,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -47,7 +47,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -106,7 +106,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -64,7 +64,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -123,7 +123,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -54,7 +54,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -73,7 +73,7 @@ jobs:
     - name: Inference with pre-trained model
       shell: bash
       run: |
-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -63,7 +63,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -122,7 +122,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -63,7 +63,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -122,7 +122,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -54,7 +54,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -73,7 +73,7 @@ jobs:
     - name: Inference with pre-trained model
       shell: bash
       run: |
-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -54,7 +54,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -73,7 +73,7 @@ jobs:
     - name: Inference with pre-trained model
       shell: bash
       run: |
-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -63,7 +63,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -122,7 +122,7 @@ jobs:
         ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
         ls -lh egs/librispeech/ASR/data/*

-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

@@ -54,7 +54,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -73,7 +73,7 @@ jobs:
     - name: Inference with pre-trained model
       shell: bash
       run: |
-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

.github/workflows/run-ptb-rnn-lm.yml (vendored, 2 lines changed)

@@ -47,7 +47,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Prepare data
       shell: bash

@@ -54,7 +54,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

@@ -76,7 +76,7 @@ jobs:
         GITHUB_EVENT_NAME: ${{ github.event_name }}
         GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
       run: |
-        sudo apt-get -qq install git-lfs tree sox
+        sudo apt-get -qq install git-lfs tree
         export PYTHONPATH=$PWD:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

.github/workflows/run-yesno-recipe.yml (vendored, 2 lines changed)

@@ -67,7 +67,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Run yesno recipe
       shell: bash

.github/workflows/test-ncnn-export.yml (vendored, 2 lines changed)

@@ -46,7 +46,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

.github/workflows/test-onnx-export.yml (vendored, 2 lines changed)

@@ -46,7 +46,7 @@ jobs:
       run: |
         grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

     - name: Cache kaldifeat
       id: my-cache

.github/workflows/test.yml (vendored, 12 lines changed)

@@ -56,7 +56,7 @@ jobs:
       run: |
         sudo apt update
         sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
-        sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
+        sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all

     - name: Install Python dependencies
       run: |

@@ -70,7 +70,7 @@ jobs:
         pip install git+https://github.com/lhotse-speech/lhotse
         # icefall requirements
         pip uninstall -y protobuf
-        pip install --no-binary protobuf protobuf
+        pip install --no-binary protobuf protobuf==3.20.*

         pip install kaldifst
         pip install onnxruntime

@@ -119,8 +119,8 @@ jobs:
         cd ../transducer_stateless
         pytest -v -s

-        cd ../transducer
-        pytest -v -s
+        # cd ../transducer
+        # pytest -v -s

         cd ../transducer_stateless2
         pytest -v -s

@@ -157,8 +157,8 @@ jobs:
         cd ../transducer_stateless
         pytest -v -s

-        cd ../transducer
-        pytest -v -s
+        # cd ../transducer
+        # pytest -v -s

         cd ../transducer_stateless2
         pytest -v -s

@@ -391,18 +391,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True

@@ -412,9 +408,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
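The save_results hunks above (and the matching aishell ones at the end of this diff) drop the decoding-method `key` from the per-test-set file names. The key still appears inside the write_error_stats output, but when results_dict holds several decoding settings, the recogs/errs/wer-summary files now share one name per test set, so later keys overwrite earlier files. A small illustration of the old versus new paths, with hypothetical values rather than icefall code:

```python
# Old vs. new result-file naming; res_dir, key and suffix are made-up examples.
from pathlib import Path

res_dir = Path("exp/ctc_decoding")
test_set_name, key, suffix = "test-clean", "beam_4", "epoch-30-avg-9"

old_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
new_path = res_dir / f"recogs-{test_set_name}-{suffix}.txt"
print(old_path)  # exp/ctc_decoding/recogs-test-clean-beam_4-epoch-30-avg-9.txt
print(new_path)  # exp/ctc_decoding/recogs-test-clean-epoch-30-avg-9.txt
```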
egs/aishell/ASR/local/prepare_char_lm_training_data.py (new file, 164 lines)

@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
+#                                                 Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a `tokens.txt` and a text file such as
+./download/lm/aishell-transcript.txt
+and outputs the LM training data to a supplied directory such
+as data/lm_training_char. The format is as follows:
+
+It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
+representation of a dict with the same format as the librispeech recipe.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-char",
+        type=str,
+        help="""Lang dir of asr model, e.g. data/lang_char""",
+    )
+    parser.add_argument(
+        "--lm-data",
+        type=str,
+        help="""Input LM training data as text, e.g.
+        download/lm/aishell-train-word.txt""",
+    )
+    parser.add_argument(
+        "--lm-archive",
+        type=str,
+        help="""Path to output archive, e.g. data/lm_training_char/lm_data.pt;
+        look at the source of this script to see the format.""",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    if Path(args.lm_archive).exists():
+        logging.warning(f"{args.lm_archive} exists - skipping")
+        return
+
+    # make token_dict from tokens.txt in order to map characters to tokens.
+    token_dict = {}
+    token_file = args.lang_char + "/tokens.txt"
+
+    with open(token_file, "r") as f:
+        for line in f.readlines():
+            line_list = line.split()
+            token_dict[line_list[0]] = int(line_list[1])
+
+    # word2index is a dictionary from words to integer ids. No need to reserve
+    # space for epsilon, etc.; the words are just used as a convenient way to
+    # compress the sequences of tokens.
+    word2index = dict()
+
+    word2token = []  # Will be a list-of-list-of-int, representing tokens.
+    sentences = []  # Will be a list-of-list-of-int, representing word-ids.
+
+    if "aishell-lm" in args.lm_data:
+        num_lines_in_total = 120098.0
+        step = 50000
+    elif "valid" in args.lm_data:
+        num_lines_in_total = 14326.0
+        step = 3000
+    elif "test" in args.lm_data:
+        num_lines_in_total = 7176.0
+        step = 3000
+    else:
+        num_lines_in_total = None
+        step = None
+
+    processed = 0
+
+    with open(args.lm_data) as f:
+        while True:
+            line = f.readline()
+            if line == "":
+                break
+
+            if step and processed % step == 0:
+                logging.info(
+                    f"Processed number of lines: {processed} "
+                    f"({processed / num_lines_in_total * 100: .3f}%)"
+                )
+            processed += 1
+
+            line_words = line.split()
+            for w in line_words:
+                if w not in word2index:
+                    w_token = []
+                    for t in w:
+                        if t in token_dict:
+                            w_token.append(token_dict[t])
+                        else:
+                            w_token.append(token_dict["<unk>"])
+                    word2index[w] = len(word2token)
+                    word2token.append(w_token)
+            sentences.append([word2index[w] for w in line_words])
+
+    logging.info("Constructing ragged tensors")
+    words = k2.ragged.RaggedTensor(word2token)
+    sentences = k2.ragged.RaggedTensor(sentences)
+
+    output = dict(words=words, sentences=sentences)
+
+    num_sentences = sentences.dim0
+    logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
+    sentence_lengths = [0] * num_sentences
+    for i in range(num_sentences):
+        if step and i % step == 0:
+            logging.info(
+                f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)"
+            )
+
+        word_ids = sentences[i]
+
+        # NOTE: If word_ids is a tensor with only 1 entry,
+        # token_ids is a torch.Tensor
+        token_ids = words[word_ids]
+        if isinstance(token_ids, k2.RaggedTensor):
+            token_ids = token_ids.values
+
+        # token_ids is a 1-D tensor containing the BPE tokens
+        # of the current sentence
+
+        sentence_lengths[i] = token_ids.numel()
+
+    output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)
+
+    torch.save(output, args.lm_archive)
+    logging.info(f"Saved to {args.lm_archive}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
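The archive this script writes is a plain dict saved with torch.save: two k2 ragged tensors (`words` mapping word ids to token-id lists, `sentences` mapping sentences to word-id lists) plus a per-sentence token count. A minimal sketch of reading it back, using the default path from prepare.sh below rather than anything the script itself fixes:

```python
# Inspect the LM training archive produced above (an illustration only).
import k2
import torch

data = torch.load("data/lm_training_char/lm_data.pt")

words = data["words"]               # k2.RaggedTensor: word id -> token ids
sentences = data["sentences"]       # k2.RaggedTensor: sentence -> word ids
lengths = data["sentence_lengths"]  # int32 tensor: tokens per sentence

# Reconstruct the token ids of the first sentence, mirroring the indexing
# the script performs when computing sentence_lengths.
word_ids = sentences[0]
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
    token_ids = token_ids.values
print(token_ids.numel(), int(lengths[0]))  # the two counts should agree
```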
@ -7,7 +7,7 @@ set -eou pipefail
|
|||||||
|
|
||||||
nj=15
|
nj=15
|
||||||
stage=-1
|
stage=-1
|
||||||
stop_stage=10
|
stop_stage=11
|
||||||
|
|
||||||
# We assume dl_dir (download dir) contains the following
|
# We assume dl_dir (download dir) contains the following
|
||||||
# directories and files. If not, they will be downloaded
|
# directories and files. If not, they will be downloaded
|
||||||
@ -219,3 +219,93 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
 ./local/compile_hlg.py --lang-dir $lang_phone_dir
 ./local/compile_hlg.py --lang-dir $lang_char_dir
 fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Generate LM training data"
+
+  log "Processing char based data"
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir $dl_dir/lm
+
+  if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
+    cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-train-word.txt \
+    --lm-archive $out_dir/lm_data.pt
+
+  if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
+    find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-valid-word.txt \
+    --lm-archive $out_dir/lm_data_valid.pt
+
+  if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
+    find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-test-word.txt \
+    --lm-archive $out_dir/lm_data_test.pt
+fi
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals the number of tokens
+  # in a sentence.
+
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir
+  ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data.pt \
+    --out-lm-data $out_dir/sorted_lm_data.pt \
+    --out-statistics $out_dir/statistics.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_valid.pt \
+    --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+    --out-statistics $out_dir/statistics-valid.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_test.pt \
+    --out-lm-data $out_dir/sorted_lm_data-test.pt \
+    --out-statistics $out_dir/statistics-test.txt
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Train RNN LM model"
+  python ../../../icefall/rnn_lm/train.py \
+    --start-epoch 0 \
+    --world-size 1 \
+    --num-epochs 20 \
+    --use-fp16 0 \
+    --embedding-dim 512 \
+    --hidden-dim 512 \
+    --num-layers 2 \
+    --batch-size 400 \
+    --exp-dir rnnlm_char/exp \
+    --lm-data data/lm_training_char/sorted_lm_data.pt \
+    --lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
+    --vocab-size 4336 \
+    --master-port 12345
+fi
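The comment in Stage 10 states the rationale: sorting sentences by token count lets the LM trainer batch utterances of similar length and waste little compute on padding. Below is a minimal sketch of that idea with illustrative names only; the real implementation is the librispeech `local/sort_lm_training_data.py` symlinked above, which also handles the archive format.

```python
# Minimal sketch: sort tokenized sentences by length, descending.
# `sentences` stands in for the token-id lists stored in lm_data.pt;
# the actual .pt archive handling lives in sort_lm_training_data.py.
from typing import List


def sort_by_length(sentences: List[List[int]]) -> List[List[int]]:
    # Sentence length equals the number of tokens in a sentence.
    return sorted(sentences, key=len, reverse=True)


sentences = [[5, 9], [3, 1, 4, 1, 5], [2, 7, 1]]
print([len(s) for s in sort_by_length(sentences)])  # [5, 3, 2]
```

Batches drawn from adjacent positions in the sorted archive then have nearly uniform length, which is what makes training easier.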
@@ -388,18 +388,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -413,9 +409,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
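The same two-part edit recurs in the decode scripts of many recipes below: the `{key}` component is dropped from each result filename and the parenthesized pathlib expression is collapsed onto one line. Presumably the decoding-method key is folded into `params.suffix` elsewhere in this change; otherwise results for different keys would overwrite one another. A toy rendering of the pattern follows; all names are stand-ins for the corresponding `params` fields.

```python
# Before/after of the recurring filename change; res_dir, test_set_name,
# key and suffix are stand-ins for the params fields in the decode scripts.
from pathlib import Path

res_dir = Path("exp/results")
test_set_name, key, suffix = "test", "greedy_search", "epoch-30-avg-9"

old_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"  # removed
new_path = res_dir / f"recogs-{test_set_name}-{suffix}.txt"        # added
print(old_path)
print(new_path)
```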
@@ -406,18 +406,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -431,9 +427,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -325,17 +325,13 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -349,9 +345,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -370,18 +370,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -395,9 +391,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -374,18 +374,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -399,9 +395,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
@@ -543,18 +543,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -564,9 +560,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -406,18 +406,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -427,9 +423,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -391,18 +391,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -412,9 +408,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -462,18 +462,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -483,9 +479,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -478,17 +478,13 @@ def save_results(
     test_set_wers = dict()
     test_set_cers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        wers_filename = (
-            params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt"
         with open(wers_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -499,9 +495,7 @@ def save_results(
         results_char = []
         for res in results:
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
-        cers_filename = (
-            params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt"
         with open(cers_filename, "w") as f:
             cer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -512,9 +506,7 @@ def save_results(

     test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
     test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
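The CER branch in the hunk above reuses `write_error_stats` by re-tokenizing each result at the character level: a result is a `(cut_id, ref_words, hyp_words)` triple, and joining the words then splitting into characters turns a word-level comparison into a character-level one. A small demonstration of just that conversion, with a made-up triple:

```python
# Word-level result -> character-level result, as done for CER above.
res = ("cut-001", ["hello", "world"], ["hello", "word"])
ref_chars = list("".join(res[1]))  # ['h','e','l','l','o','w','o','r','l','d']
hyp_chars = list("".join(res[2]))
results_char = [(res[0], ref_chars, hyp_chars)]
print(results_char)
```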
@@ -599,9 +599,7 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)

@@ -609,9 +607,7 @@ def save_results(

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -621,9 +617,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -399,9 +399,7 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = post_processing(results)
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -409,9 +407,7 @@ def save_results(

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -421,9 +417,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -540,6 +540,10 @@ for m in greedy_search fast_beam_search modified_beam_search ; do
 done
 ```

+Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in
+this [PR](/ceph-data4/yangxiaoyu/softwares/icefall_development/icefall_random_padding/egs/librispeech/ASR/pruned_transducer_stateless7/exp_960h_no_paddingidx_ngpu4/tensorboard) to address the
+problem of emitting the first symbol at the very beginning. If you need a
+model without this issue, please download the model from here: <https://huggingface.co/marcoyang/icefall-asr-librispeech-pruned-transducer-stateless7-2023-03-10>

 ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)

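The link target above leaks a local experiment path whose name ("exp_960h_no_paddingidx_ngpu4") hints that the `decoder.py` change removed a `padding_idx` from the decoder embedding. The diff itself does not show the change, so treat the following strictly as an assumption: in PyTorch, the embedding row designated as `padding_idx` receives no gradient, so if the blank symbol were used as `padding_idx`, its embedding would never train, which could plausibly distort the first emitted symbol.

```python
# Hedged sketch (an assumption, not a statement of what the PR did):
# the padding_idx row of nn.Embedding gets zero gradient and never trains.
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=500, embedding_dim=8, padding_idx=0)
emb(torch.tensor([0, 1, 2])).sum().backward()
print(emb.weight.grad[0].abs().sum())  # tensor(0.): the padding row is frozen
print(emb.weight.grad[1].abs().sum())  # non-zero: other rows train normally
```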
@@ -728,18 +728,14 @@ def save_results(
     test_set_wers = dict()
     test_set_delays = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f,
@@ -754,9 +750,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -765,8 +759,7 @@ def save_results(
     # sort according to the mean start symbol delay
     test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
     delays_info = (
-        params.res_dir
-        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
         print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
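The final hunk above sorts with a nested index, `key=lambda x: x[1][0][0]`, which implies each `test_set_delays` value is a pair of tuples whose first element holds the `(start, end)` mean delays; the sort is therefore by mean start-symbol delay, matching the comment. An illustration with fabricated numbers, shaped only by what the sort key implies:

```python
# Shape of test_set_delays values inferred from the sort key alone:
# ((start_mean, end_mean), (start_var, end_var)) -- numbers are illustrative.
test_set_delays = {
    "greedy_search": ((0.32, 0.41), (0.01, 0.02)),
    "modified_beam_search": ((0.28, 0.39), (0.01, 0.02)),
}
ordered = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
print([k for k, _ in ordered])  # smallest mean start delay first
```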
@@ -432,18 +432,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -453,9 +449,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -750,17 +750,13 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -770,9 +766,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -432,18 +432,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -453,9 +449,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -750,17 +750,13 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -770,9 +766,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
egs/librispeech/ASR/finetune.sh (new executable file, 85 lines)
@@ -0,0 +1,85 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# This is an example script for fine-tuning. Here, we fine-tune a model trained
+# on LibriSpeech on GigaSpeech. The model used for fine-tuning is
+# pruned_transducer_stateless7 (zipformer). If you want to fine-tune a model
+# from another recipe, you can adapt ./pruned_transducer_stateless7/finetune.py
+# for that recipe. If you have any problem, please open up an issue in https://github.com/k2-fsa/icefall/issues.
+
+# We assume that you have already prepared the GigaSpeech manifest and features under ./data.
+# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/gigaspeech/ASR/prepare.sh.
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: Download pre-trained model"
+
+  # clone from huggingface
+  git lfs install
+  git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Start fine-tuning"
+
+  # The following configuration of the lr schedule should work well.
+  # You may also tune these parameters to adjust the learning rate schedule.
+  base_lr=0.005
+  lr_epochs=100
+  lr_batches=100000
+
+  # We recommend starting from an averaged model
+  finetune_ckpt=icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt
+  export CUDA_VISIBLE_DEVICES="0,1"
+
+  ./pruned_transducer_stateless7/finetune.py \
+    --world-size 2 \
+    --master-port 18180 \
+    --num-epochs 20 \
+    --start-epoch 1 \
+    --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+    --subset S \
+    --use-fp16 1 \
+    --base-lr $base_lr \
+    --lr-epochs $lr_epochs \
+    --lr-batches $lr_batches \
+    --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
+    --do-finetune True \
+    --finetune-ckpt $finetune_ckpt \
+    --max-duration 500
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Decoding"
+
+  epoch=15
+  avg=10
+
+  for m in greedy_search modified_beam_search; do
+    python pruned_transducer_stateless7/decode_gigaspeech.py \
+      --epoch $epoch \
+      --avg $avg \
+      --use-averaged-model True \
+      --beam-size 4 \
+      --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+      --max-duration 400 \
+      --decoding-method $m
+  done
+fi
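Assuming `shared/parse_options.sh` provides the usual kaldi-style command-line overrides for the variables defined at the top of the script, the stages can be run selectively, e.g. `./finetune.sh --stage 0 --stop-stage 0` to repeat only the fine-tuning step once the pre-trained model has been downloaded, or `./finetune.sh --stage 1` to re-run decoding alone.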
@@ -54,10 +54,20 @@ def get_args():
         help="""Path to the bpe.model. If not None, we will remove short and
         long utterances before extracting features""",
     )

+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="""Dataset parts to compute fbank. If None, we will use all""",
+    )
+
     return parser.parse_args()


-def compute_fbank_librispeech(bpe_model: Optional[str] = None):
+def compute_fbank_librispeech(
+    bpe_model: Optional[str] = None,
+    dataset: Optional[str] = None,
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -68,15 +78,19 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None):
         sp = spm.SentencePieceProcessor()
         sp.load(bpe_model)

-    dataset_parts = (
-        "dev-clean",
-        "dev-other",
-        "test-clean",
-        "test-other",
-        "train-clean-100",
-        "train-clean-360",
-        "train-other-500",
-    )
+    if dataset is None:
+        dataset_parts = (
+            "dev-clean",
+            "dev-other",
+            "test-clean",
+            "test-other",
+            "train-clean-100",
+            "train-clean-360",
+            "train-other-500",
+        )
+    else:
+        dataset_parts = dataset.split(" ", -1)

     prefix = "librispeech"
     suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
@@ -131,4 +145,4 @@ if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()
     logging.info(vars(args))
-    compute_fbank_librispeech(bpe_model=args.bpe_model)
+    compute_fbank_librispeech(bpe_model=args.bpe_model, dataset=args.dataset)
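One detail of the new `--dataset` handling: `split(" ", -1)` passes `maxsplit=-1`, which means "no limit", so the call behaves like a plain split on single spaces and the argument is expected to be one space-separated string of dataset part names.

```python
# Behavior of the split used for --dataset above.
dataset = "test-clean test-other"
print(dataset.split(" ", -1))  # ['test-clean', 'test-other']
# Caveat: split(" ") keeps empty strings for runs of spaces, unlike split().
print("a  b".split(" ", -1))   # ['a', '', 'b']
```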
@@ -566,18 +566,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -587,9 +583,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -742,17 +742,13 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -762,9 +758,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -702,18 +702,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -723,9 +719,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -611,18 +611,14 @@ def save_results(
     test_set_wers = dict()
     test_set_delays = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -633,9 +629,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -643,8 +637,7 @@ def save_results(

     test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
     delays_info = (
-        params.res_dir
-        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
         print("settings\tsymbol-delay", file=f)
@@ -742,17 +742,13 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -762,9 +758,7 @@ def save_results(
     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
|||||||
@@ -386,17 +386,13 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -406,9 +402,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -420,18 +420,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -441,9 +437,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -585,18 +585,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -606,9 +602,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -58,7 +58,6 @@ class Decoder(nn.Module):
         self.embedding = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
-            padding_idx=blank_id,
         )
         self.blank_id = blank_id
         self.unk_id = unk_id
@@ -423,9 +423,7 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         # sort results so we can easily compare the difference between two
         # recognition results
         results = sorted(results)
@@ -434,9 +432,7 @@ def save_results(
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -446,9 +442,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -609,18 +609,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -630,9 +626,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -59,7 +59,6 @@ class Decoder(nn.Module):
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
             embedding_dim=decoder_dim,
-            padding_idx=blank_id,
         )
         self.blank_id = blank_id
 
@@ -425,9 +425,7 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         # sort results so we can easily compare the difference between two
         # recognition results
         results = sorted(results)
@@ -436,9 +434,7 @@ def save_results(
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -448,9 +444,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -869,18 +869,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -890,9 +886,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -426,18 +426,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -447,9 +443,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -676,18 +676,14 @@ def save_results(
     test_set_wers = dict()
     test_set_delays = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -698,9 +694,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -708,8 +702,7 @@ def save_results(
 
     test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
     delays_info = (
-        params.res_dir
-        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
         print("settings\tsymbol-delay", file=f)
@@ -442,18 +442,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -463,9 +459,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -735,18 +735,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -756,9 +752,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -442,18 +442,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -463,9 +459,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -416,18 +416,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -437,9 +433,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -722,18 +722,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -743,9 +739,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
        for key, val in test_set_wers:
@ -0,0 +1,861 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Xiaoyu Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) greedy search
|
||||||
|
./pruned_transducer_stateless7/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
(2) beam search (not recommended)
|
||||||
|
./pruned_transducer_stateless7/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./pruned_transducer_stateless7/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search (one best)
|
||||||
|
./pruned_transducer_stateless7/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless7/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless7/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(7) fast beam search (with LG)
|
||||||
|
./pruned_transducer_stateless7/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_LG \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from gigaspeech import GigaSpeechAsrDataModule
|
||||||
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG,
|
||||||
|
fast_beam_search_nbest_oracle,
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=9,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_bpe_500",
|
||||||
|
help="The lang dir containing word table and LG graph",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
|
- fast_beam_search_nbest_LG
|
||||||
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=20.0,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding_method is greedy_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulate-streaming",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||||
|
test a streaming model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--left-context",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def post_processing(
|
||||||
|
results: List[Tuple[str, List[str], List[str]]],
|
||||||
|
) -> List[Tuple[str, List[str], List[str]]]:
|
||||||
|
new_results = []
|
||||||
|
for key, ref, hyp in results:
|
||||||
|
new_ref = asr_text_post_processing(" ".join(ref)).split()
|
||||||
|
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
|
||||||
|
new_results.append((key, new_ref, new_hyp))
|
||||||
|
return new_results
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
batch: dict,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if greedy_search is used, it would be "greedy_search"
|
||||||
|
If beam search with a beam size of 7 is used, it would be
|
||||||
|
"beam_7"
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
feature_lens += params.left_context
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
chunk_size=params.decode_chunk_size,
|
||||||
|
left_context=params.left_context,
|
||||||
|
simulate_streaming=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_LG(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append([word_table[i] for i in hyp])
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
# fmt: off
|
||||||
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
|
# fmt: on
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "beam_search":
|
||||||
|
hyp = beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
return {"greedy_search": hyps}
|
||||||
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
|
key = f"beam_{params.beam}_"
|
||||||
|
key += f"max_contexts_{params.max_contexts}_"
|
||||||
|
key += f"max_states_{params.max_states}"
|
||||||
|
if "nbest" in params.decoding_method:
|
||||||
|
key += f"_num_paths_{params.num_paths}_"
|
||||||
|
key += f"nbest_scale_{params.nbest_scale}"
|
||||||
|
if "LG" in params.decoding_method:
|
||||||
|
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||||
|
|
||||||
|
return {key: hyps}
|
||||||
|
else:
|
||||||
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
log_interval = 50
|
||||||
|
else:
|
||||||
|
log_interval = 20
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
word_table=word_table,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = (
|
||||||
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
results = post_processing(results)
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = (
|
||||||
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = (
|
||||||
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def main():
    """
    This script tests a LibriSpeech-trained model, using the LibriSpeech
    BPE vocabulary, on GigaSpeech.
    """
    parser = get_parser()
    GigaSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech")

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.simulate_streaming:
        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
        params.suffix += f"-left-context-{params.left_context}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
            if "LG" in params.decoding_method:
                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"
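    # Example (hypothetical flag values): running with --epoch 30 --avg 9 and
    # --decoding-method fast_beam_search would yield something like
    #   params.suffix == "epoch-30-avg-9-beam-4-max-contexts-4-max-states-8"
    # so every recogs-*/errs-*/wer-summary-* file name records the exact
    # decoding configuration that produced it.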
    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()
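    # A quick sketch (standard sentencepiece API; the text is illustrative) of
    # what the loaded BPE model provides during decoding:
    #   ids = sp.encode("HELLO WORLD", out_type=int)  # text -> token ids
    #   txt = sp.decode(ids)                          # token ids -> text
    # <blk> is the transducer blank symbol; <unk> absorbs out-of-vocabulary
    # pieces.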
    if params.simulate_streaming:
        assert (
            params.causal_convolution
        ), "Decoding in streaming requires causal convolution"

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)
    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
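    # Note on the two averaging modes above: average_checkpoints() takes a
    # plain element-wise mean of the saved weights, whereas
    # average_checkpoints_with_averaged_model() reconstructs the mean over the
    # half-open range (start, end] from the running averages stored inside the
    # start and end checkpoints, hence "(excluded)" in the log messages.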
    model.to(device)
    model.eval()

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None
        word_table = None
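    # Only fast_beam_search_nbest_LG constrains the search with an LG graph
    # (lexicon composed with an n-gram LM, rescaled by --ngram-lm-scale); the
    # remaining fast_beam_search variants use k2.trivial_graph(), which
    # accepts any token sequence and so imposes no LM constraint.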
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    gigaspeech = GigaSpeechAsrDataModule(args)

    dev_cuts = gigaspeech.dev_cuts()
    test_cuts = gigaspeech.test_cuts()

    dev_dl = gigaspeech.test_dataloaders(dev_cuts)
    test_dl = gigaspeech.test_dataloaders(test_cuts)

    test_sets = ["dev", "test"]
    test_dls = [dev_dl, test_dl]
    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")

if __name__ == "__main__":
    main()
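# A hypothetical invocation (the exact flag names come from get_parser() and
# GigaSpeechAsrDataModule.add_arguments(), which are defined elsewhere; the
# script name and paths below are illustrative):
#
#   ./decode-giga.py \
#     --epoch 30 \
#     --avg 9 \
#     --use-averaged-model True \
#     --exp-dir ./exp \
#     --bpe-model ./data/lang_bpe_500/bpe.model \
#     --decoding-method fast_beam_search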