Merge remote-tracking branch 'dan/master' into deeper-conformer

This commit is contained in:
Fangjun Kuang 2022-05-13 07:41:50 +08:00
commit 994b8a7716
62 changed files with 4610 additions and 867 deletions

View File

@ -9,6 +9,7 @@ per-file-ignores =
egs/tedlium3/ASR/*/conformer.py: E501,
egs/gigaspeech/ASR/*/conformer.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless4/*.py: E501,
egs/librispeech/ASR/*/optim.py: E501,
egs/librispeech/ASR/*/scaling.py: E501,

View File

@ -0,0 +1,17 @@
#!/usr/bin/env bash
# This script computes fbank features for the test-clean and test-other datasets.
# The computed features are saved to ~/tmp/fbank-libri and are
# cached for later runs
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
mkdir ~/tmp/fbank-libri
cd egs/librispeech/ASR
mkdir -p data
cd data
[ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank
cd ..
./local/compute_fbank_librispeech.py
ls -lh data/fbank/

View File

@ -0,0 +1,23 @@
#!/usr/bin/env bash
# This script downloads the test-clean and test-other datasets
# of LibriSpeech and unzip them to the folder ~/tmp/download,
# which is cached by GitHub actions for later runs.
#
# You will find directories ~/tmp/download/LibriSpeech after running
# this script.
mkdir ~/tmp/download
cd egs/librispeech/ASR
ln -s ~/tmp/download .
cd download
wget -q --no-check-certificate https://www.openslr.org/resources/12/test-clean.tar.gz
tar xf test-clean.tar.gz
rm test-clean.tar.gz
wget -q --no-check-certificate https://www.openslr.org/resources/12/test-other.tar.gz
tar xf test-other.tar.gz
rm test-other.tar.gz
pwd
ls -lh
ls -lh LibriSpeech

13
.github/scripts/install-kaldifeat.sh vendored Executable file
View File

@ -0,0 +1,13 @@
#!/usr/bin/env bash
# This script installs kaldifeat into the directory ~/tmp/kaldifeat
# which is cached by GitHub actions for later runs.
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat

View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
# This script assumes that test-clean and test-other are downloaded
# to egs/librispeech/ASR/download/LibriSpeech and generates manifest
# files in egs/librispeech/ASR/data/manifests
cd egs/librispeech/ASR
[ ! -e download ] && ln -s ~/tmp/download .
mkdir -p data/manifests
lhotse prepare librispeech -j 2 -p test-clean -p test-other ./download/LibriSpeech data/manifests
ls -lh data/manifests

View File

@ -33,7 +33,7 @@ for sym in 1 2 3; do
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
for method in fast_beam_search modified_beam_search beam_search; do
log "$method"
./pruned_transducer_stateless/pretrained.py \
@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless/exp
done
rm pruned_transducer_stateless/exp/*.pt
fi

View File

@ -49,3 +49,32 @@ for method in modified_beam_search beam_search fast_beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless2/exp
done
rm pruned_transducer_stateless2/exp/*.pt
fi

View File

@ -49,3 +49,32 @@ for method in modified_beam_search beam_search fast_beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless3/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh pruned_transducer_stateless3/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless3/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless3/exp
done
rm pruned_transducer_stateless3/exp/*.pt
fi

View File

@ -33,7 +33,7 @@ for sym in 1 2 3; do
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
for method in fast_beam_search modified_beam_search beam_search; do
log "$method"
./transducer_stateless2/pretrained.py \
@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh transducer_stateless2/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./transducer_stateless2/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir transducer_stateless2/exp
done
rm transducer_stateless2/exp/*.pt
fi

View File

@ -33,7 +33,7 @@ for sym in 1 2 3; do
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./transducer_stateless_multi_datasets/pretrained.py \
@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p transducer_stateless_multi_datasets/exp
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh transducer_stateless_multi_datasets/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./transducer_stateless_multi_datasets/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir transducer_stateless_multi_datasets/exp
done
rm transducer_stateless_multi_datasets/exp/*.pt
fi

View File

@ -33,7 +33,7 @@ for sym in 1 2 3; do
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./transducer_stateless_multi_datasets/pretrained.py \
@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p transducer_stateless_multi_datasets/exp
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
ls -lh transducer_stateless_multi_datasets/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./transducer_stateless_multi_datasets/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir transducer_stateless_multi_datasets/exp
done
rm transducer_stateless_multi_datasets/exp/*.pt
fi

View File

@ -33,7 +33,7 @@ for sym in 1 2 3; do
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
for method in fast_beam_search modified_beam_search beam_search; do
log "$method"
./transducer_stateless/pretrained.py \
@ -46,15 +46,31 @@ for method in modified_beam_search beam_search; do
$repo/test_wavs/1221-135766-0002.wav
done
for method in modified_beam_search beam_search; do
log "$method"
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p transducer_stateless/exp
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
./transducer_stateless_multi_datasets/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
done
ls -lh data
ls -lh transducer_stateless/exp
log "Decoding test-clean and test-other"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./transducer_stateless/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir transducer_stateless/exp
done
rm transducer_stateless/exp/*.pt
fi

View File

@ -24,9 +24,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_2022_03_12:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -63,20 +72,82 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
- name: Display decoding results for pruned_transducer_stateless
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./pruned_transducer_stateless/exp
cd pruned_transducer_stateless
echo "results for pruned_transducer_stateless"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
path: egs/librispeech/ASR/pruned_transducer_stateless/exp/

View File

@ -24,9 +24,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_2022_04_29:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -63,18 +72,51 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
@ -83,3 +125,55 @@ jobs:
.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
- name: Display decoding results for pruned_transducer_stateless2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR
tree pruned_transducer_stateless2/exp
cd pruned_transducer_stateless2/exp
echo "===greedy search==="
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Display decoding results for pruned_transducer_stateless3
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR
tree pruned_transducer_stateless3/exp
cd pruned_transducer_stateless3/exp
echo "===greedy search==="
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
- name: Upload decoding results for pruned_transducer_stateless3
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/

View File

@ -24,9 +24,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_librispeech_2022_04_19:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -63,20 +72,82 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
- name: Display decoding results
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./transducer_stateless2/exp
cd transducer_stateless2
echo "results for transducer_stateless2"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified_beam_search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
path: egs/librispeech/ASR/transducer_stateless2/exp/

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -23,9 +23,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -62,20 +71,82 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
- name: Display decoding results for transducer_stateless_multi_datasets
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./transducer_stateless_multi_datasets/exp
cd transducer_stateless_multi_datasets
echo "results for transducer_stateless_multi_datasets"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for transducer_stateless_multi_datasets
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/

View File

@ -23,9 +23,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -62,20 +71,82 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
- name: Display decoding results for transducer_stateless_multi_datasets
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./transducer_stateless_multi_datasets/exp
cd transducer_stateless_multi_datasets
echo "results for transducer_stateless_multi_datasets"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for transducer_stateless_multi_datasets
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -62,14 +62,7 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-trandsucer-stateless
name: run-pre-trained-transducer-stateless
on:
push:
@ -23,9 +23,18 @@ on:
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_pre_trained_transducer_stateless:
if: github.event.label.name == 'ready' || github.event_name == 'push'
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@ -62,20 +71,82 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
.github/scripts/install-kaldifeat.sh
- name: Cache LibriSpeech test-clean and test-other datasets
id: libri-test-clean-and-test-other-data
uses: actions/cache@v2
with:
path: |
~/tmp/download
key: cache-libri-test-clean-and-test-other
- name: Download LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
- name: Prepare manifests for LibriSpeech test-clean and test-other
shell: bash
run: |
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
- name: Cache LibriSpeech test-clean and test-other fbank features
id: libri-test-clean-and-test-other-fbank
uses: actions/cache@v2
with:
path: |
~/tmp/fbank-libri
key: cache-libri-fbank-test-clean-and-test-other
- name: Compute fbank for LibriSpeech test-clean and test-other
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-pre-trained-transducer-stateless.sh
- name: Display decoding results for transducer_stateless
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/librispeech/ASR/
tree ./transducer_stateless/exp
cd transducer_stateless
echo "results for transducer_stateless"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for transducer_stateless
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
path: egs/librispeech/ASR/transducer_stateless/exp/

View File

@ -62,13 +62,6 @@ jobs:
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make -j2 _kaldifeat
- name: Inference with pre-trained model

View File

@ -110,7 +110,9 @@ class Conformer(Transformer):
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -22,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
@ -33,20 +34,19 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
cd /path/to/egs/aishell/ASR
./transducer_stateless/decode.py \
--exp-dir ./transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
@ -56,6 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
@ -91,10 +92,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
@ -194,12 +195,10 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
lexicon = Lexicon(params.lang_dir)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)

View File

@ -19,49 +19,62 @@
Usage:
(1) greedy search
./transducer_stateless_modified-2/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method greedy_search
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_stateless_modified/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(2) beam search (not recommended)
./transducer_stateless_modified-2/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_stateless_modified-2/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless_modified-2/decode.py \
--epoch 89 \
--avg 38 \
--exp-dir ./transducer_stateless_modified-2/exp \
--max-duration 100 \
--decoding-method fast_beam_search \
--beam-size 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from aishell import AIShell
from asr_datamodule import AsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -114,6 +127,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -121,8 +135,35 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search "
"and modified_beam_search",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@ -132,84 +173,24 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
token_table: k2.SymbolTable,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -230,8 +211,8 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the token symbol table and the word symbol table.
token_table:
It maps token ID to a string.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -249,44 +230,80 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
hyp_tokens = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_tokens.append(hyp)
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -297,6 +314,11 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
token_table:
It maps a token ID to a string.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -312,9 +334,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -323,7 +345,8 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
token_table=token_table,
decoding_graph=decoding_graph,
batch=batch,
)
@ -358,6 +381,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
@ -408,13 +432,21 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -456,6 +488,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -472,7 +509,8 @@ def main():
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
token_table=lexicon.token_table,
decoding_graph=decoding_graph,
)
save_results(
@ -484,8 +522,5 @@ def main():
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -19,7 +19,7 @@
"""
Usage:
# greedy search
(1) greedy search
./transducer_stateless_modified-2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
@ -27,7 +27,7 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav
# beam search
(2) beam search
./transducer_stateless_modified-2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
@ -36,7 +36,7 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav
# modified beam search
(3) modified beam search
./transducer_stateless_modified-2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
@ -45,6 +45,14 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./transducer_stateless_modified-2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
@ -53,11 +61,13 @@ import math
from pathlib import Path
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -97,6 +107,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -121,7 +132,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@ -134,11 +171,10 @@ def get_parser():
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
default=1,
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
return parser
return parser
@ -225,20 +261,37 @@ def main():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
logging.info(f"Using {params.method}")
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
for i in range(encoder_out.size(0)):
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on

View File

@ -19,48 +19,63 @@
Usage:
(1) greedy search
./transducer_stateless_modified/decode.py \
--epoch 64 \
--avg 33 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method greedy_search
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./transducer_stateless_modified/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_stateless_modified/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless_modified/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_modified/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -113,6 +128,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -120,7 +136,35 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --decoding-method is beam_search",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@ -130,84 +174,24 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
token_table: k2.SymbolTable,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -228,8 +212,11 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the token symbol table and the word symbol table.
token_table:
It maps token ID to a string.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -247,44 +234,80 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
hyp_tokens = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_tokens.append(hyp)
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -295,6 +318,11 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
token_table:
It maps a token ID to a string.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -310,9 +338,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -321,7 +349,8 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
token_table=token_table,
decoding_graph=decoding_graph,
batch=batch,
)
@ -356,6 +385,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
@ -406,13 +436,21 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -452,6 +490,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -467,7 +510,8 @@ def main():
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
token_table=lexicon.token_table,
decoding_graph=decoding_graph,
)
save_results(
@ -479,8 +523,5 @@ def main():
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -19,7 +19,7 @@
"""
Usage:
# greedy search
(1) greedy search
./transducer_stateless_modified/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
@ -27,7 +27,7 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav
# beam search
(2) beam search
./transducer_stateless_modified/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
@ -36,7 +36,7 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav
# modified beam search
(3) modified beam search
./transducer_stateless_modified/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
@ -45,6 +45,14 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./transducer_stateless_modified/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
@ -53,11 +61,13 @@ import math
from pathlib import Path
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -97,6 +107,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -121,7 +132,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@ -134,11 +171,10 @@ def get_parser():
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
default=1,
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
return parser
return parser
@ -225,20 +261,37 @@ def main():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
logging.info(f"Using {params.method}")
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
for i in range(encoder_out.size(0)):
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on

View File

@ -177,8 +177,8 @@ def post_processing(
) -> List[Tuple[List[str], List[str]]]:
new_results = []
for ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref))
new_hyp = asr_text_post_processing(" ".join(hyp))
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((new_ref, new_hyp))
return new_results

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
@ -26,6 +27,149 @@ from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search_one_best(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.
This is the best result we can achieve for any nbest based rescoring
methods.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
ref_texts:
A list-of-list of integers containing the reference transcripts.
If the decoding_graph is a trivial_graph, the integer ID is the
BPE token ID.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
@ -34,8 +178,7 @@ def fast_beam_search(
beam: float,
max_states: int,
max_contexts: int,
use_max: bool = False,
) -> List[List[int]]:
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
Args:
@ -54,11 +197,10 @@ def fast_beam_search(
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
use_max:
True to use max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return the decoded result.
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
lattice is actually an acceptor.
"""
assert encoder_out.ndim == 3
@ -91,7 +233,7 @@ def fast_beam_search(
# (shape.NumElements(), 1, encoder_out_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).long()
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
# in some old versions of pytorch, the type of index requires
# to be LongTensor. In the newest version of pytorch, the type
# of index can be IntTensor or LongTensor. For supporting the
@ -108,67 +250,7 @@ def fast_beam_search(
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
if use_max:
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
else:
num_paths = 200
use_double_scores = True
nbest_scale = 0.8
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as it is not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores, log_semiring=True
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
return lattice
def greedy_search(
@ -192,10 +274,10 @@ def greedy_search(
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64
@ -229,7 +311,7 @@ def greedy_search(
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y != blank_id and y != unk_id:
if y not in (blank_id, unk_id):
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
@ -248,7 +330,9 @@ def greedy_search(
def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -256,6 +340,9 @@ def greedy_search_batch(
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
Returns:
Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0).
@ -263,28 +350,48 @@ def greedy_search_batch(
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
device = next(model.parameters()).device
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (batch_size, context_size)
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_out: (batch_size, 1, decoder_out_dim)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# decoder_out: (N, 1, decoder_out_dim)
encoder_out = packed_encoder_out.data
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits'shape (batch_size, 1, 1, vocab_size)
@ -293,12 +400,12 @@ def greedy_search_batch(
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id and v != unk_id:
if v not in (blank_id, unk_id):
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps]
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
@ -306,7 +413,12 @@ def greedy_search_batch(
)
decoder_out = model.decoder(decoder_input, need_pad=False)
ans = [h[context_size:] for h in hyps]
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@ -471,6 +583,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[List[int]]:
@ -481,6 +594,9 @@ def modified_beam_search(
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
use_max:
@ -491,16 +607,27 @@ def modified_beam_search(
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
@ -509,9 +636,20 @@ def modified_beam_search(
use_max=use_max,
)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
encoder_out = packed_encoder_out.data
offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = _get_hyps_shape(B).to(device)
@ -565,8 +703,10 @@ def modified_beam_search(
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
@ -574,15 +714,21 @@ def modified_beam_search(
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id and new_token != unk_id:
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@ -619,10 +765,10 @@ def _deprecated_modified_beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
T = encoder_out.size(1)
@ -679,14 +825,16 @@ def _deprecated_modified_beam_search(
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id and new_token != unk_id:
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
@ -727,10 +875,10 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,
@ -813,7 +961,7 @@ def beam_search(
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i == blank_id or i == unk_id:
if i in (blank_id, unk_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v

View File

@ -19,53 +19,53 @@
Usage:
(1) greedy search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search using LG
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--use-LG True \
--use-max False \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 8 \
--max-contexts 8 \
--max-states 64
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--use-LG True \
--use-max False \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 8 \
--max-contexts 8 \
--max-states 64
"""
@ -82,7 +82,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -307,7 +307,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -315,7 +315,6 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
use_max=params.use_max,
)
if params.use_LG:
for hyp in hyp_tokens:
@ -330,6 +329,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -337,6 +337,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
use_max=params.use_max,
)
@ -421,9 +422,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):

View File

@ -25,7 +25,7 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav \
(1) beam search
(2) beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
@ -34,6 +34,24 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav \
(3) modified beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
(4) fast beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by
@ -46,12 +64,14 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -77,9 +97,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="""Path to bpe.model.""",
)
parser.add_argument(
@ -90,6 +108,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -114,7 +133,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@ -230,10 +275,25 @@ def main():
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "modified_beam_search":
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
@ -243,6 +303,7 @@ def main():
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())

View File

@ -276,7 +276,7 @@ def greedy_search(
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64
@ -335,7 +335,9 @@ def greedy_search(
def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -343,6 +345,9 @@ def greedy_search_batch(
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
Returns:
Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0).
@ -350,31 +355,49 @@ def greedy_search_batch(
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
device = next(model.parameters()).device
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (batch_size, context_size)
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
# decoder_out: (N, 1, decoder_out_dim)
# decoder_out: (batch_size, 1, decoder_out_dim)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
@ -390,7 +413,7 @@ def greedy_search_batch(
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps]
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
@ -399,7 +422,12 @@ def greedy_search_batch(
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
ans = [h[context_size:] for h in hyps]
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@ -557,6 +585,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -566,6 +595,9 @@ def modified_beam_search(
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
Returns:
@ -573,16 +605,27 @@ def modified_beam_search(
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
@ -590,11 +633,20 @@ def modified_beam_search(
)
)
encoder_out = model.joiner.encoder_proj(encoder_out)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = _get_hyps_shape(B).to(device)
@ -668,8 +720,14 @@ def modified_beam_search(
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@ -705,7 +763,7 @@ def _deprecated_modified_beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
T = encoder_out.size(1)
@ -813,7 +871,7 @@ def beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,

View File

@ -22,15 +22,15 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
@ -39,7 +39,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
@ -48,7 +48,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
@ -270,6 +270,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -277,6 +278,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
@ -356,9 +358,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):

View File

@ -156,15 +156,16 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need to "
"be changed.",
)
parser.add_argument(
"--lr-batches",
type=float,
default=5000,
help="""Number of steps that affects how rapidly the learning rate decreases.
We suggest not to change this.""",
help="""Number of steps that affects how rapidly the learning rate
decreases. We suggest not to change this.""",
)
parser.add_argument(
@ -670,25 +671,29 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5:
return
@ -933,6 +938,38 @@ def run(rank, world_size, args):
cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
@ -964,7 +1001,7 @@ def scan_pessimistic_batches_for_oom(
loss.backward()
optimizer.step()
optimizer.zero_grad()
except RuntimeError as e:
except Exception as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
@ -973,6 +1010,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, sp=sp)
raise

View File

@ -22,15 +22,15 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./pruned_transducer_stateless3/decode-giga.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
@ -39,7 +39,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
@ -48,7 +48,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 1500 \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
@ -69,7 +69,8 @@ import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -100,27 +101,28 @@ def get_parser():
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
"'--epoch' and '--iter'",
)
parser.add_argument(
@ -146,6 +148,7 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest_oracle
""",
)
@ -165,7 +168,8 @@ def get_parser():
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
Used only when --decoding-method is
fast_beam_search or fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -173,7 +177,7 @@ def get_parser():
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search or fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -181,7 +185,7 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search or fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -199,12 +203,29 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for computed nbest oracle WER
when the decoding method is fast_beam_search_nbest_oracle.
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding_method is fast_beam_search_nbest_oracle.
""",
)
return parser
def post_processing(
results: List[Tuple[List[List[str]], List[List[str]]]],
) -> List[Tuple[List[List[str]], List[List[str]]]]:
results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
new_results = []
for ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
@ -243,7 +264,8 @@ def decode_one_batch(
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is
fast_beam_search or fast_beam_search_nbest_oracle.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -264,7 +286,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -275,6 +297,21 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -328,6 +365,16 @@ def decode_one_batch(
f"max_states_{params.max_states}"
): hyps
}
elif params.decoding_method == "fast_beam_search_nbest_oracle":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -368,9 +415,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -463,17 +510,30 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / "giga" / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif params.decoding_method == "fast_beam_search_nbest_oracle":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -490,8 +550,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.unk_id()
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -499,8 +560,20 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
@ -519,13 +592,17 @@ def main():
model.to(device)
model.eval()
model.device = device
model.unk_id = params.unk_id
# In beam_search.py, we are using model.decoder() and model.joiner(),
# so we have to switch to the branch for the GigaSpeech dataset.
model.decoder = model.decoder_giga
model.joiner = model.joiner_giga
if params.decoding_method == "fast_beam_search":
if params.decoding_method in (
"fast_beam_search",
"fast_beam_search_nbest_oracle",
):
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None

View File

@ -22,15 +22,15 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
@ -39,7 +39,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
@ -48,7 +48,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 1500 \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
@ -307,6 +307,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -314,6 +315,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
@ -403,9 +405,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/__init__.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/conformer.py

View File

@ -0,0 +1,633 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless4/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/export.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -22,6 +22,235 @@ import k2
import torch
from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search_one_best(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.
This is the best result we can achieve for any nbest based rescoring
methods.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
ref_texts:
A list-of-list of integers containing the reference transcripts.
If the decoding_graph is a trivial_graph, the integer ID is the
BPE token ID.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
lattice is actually an acceptor.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
encoder_out_len = torch.ones(1, dtype=torch.int32)
decoder_out_len = torch.ones(1, dtype=torch.int32)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len.expand(decoder_out.size(0)),
decoder_out_len.expand(decoder_out.size(0)),
) # (N, vocab_size)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
return lattice
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
@ -104,7 +333,9 @@ def greedy_search(
def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -112,6 +343,9 @@ def greedy_search_batch(
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
Returns:
Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0).
@ -119,32 +353,54 @@ def greedy_search_batch(
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
device = next(model.parameters()).device
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (batch_size, context_size)
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_out: (batch_size, 1, decoder_out_dim)
# decoder_out: (N, 1, decoder_out_dim)
encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
encoder_out_len = torch.ones(1, dtype=torch.int32)
decoder_out_len = torch.ones(1, dtype=torch.int32)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
encoder_out = packed_encoder_out.data
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
current_encoder_out,
decoder_out,
encoder_out_len.expand(batch_size),
decoder_out_len.expand(batch_size),
) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
@ -157,7 +413,7 @@ def greedy_search_batch(
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps]
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
@ -168,7 +424,12 @@ def greedy_search_batch(
need_pad=False,
) # (batch_size, 1, decoder_out_dim)
ans = [h[context_size:] for h in hyps]
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@ -415,6 +676,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
@ -424,6 +686,9 @@ def modified_beam_search(
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
Returns:
@ -431,15 +696,26 @@ def modified_beam_search(
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
@ -449,9 +725,20 @@ def modified_beam_search(
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
encoder_out = packed_encoder_out.data
offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1)
# current_encoder_out's shape is: (batch_size, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = _get_hyps_shape(B).to(device)
@ -524,8 +811,14 @@ def modified_beam_search(
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans

View File

@ -19,29 +19,40 @@
Usage:
(1) greedy search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
@ -49,14 +60,16 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -115,6 +128,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -122,8 +136,35 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@ -149,6 +190,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -171,6 +213,9 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -188,24 +233,44 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyp_list: List[List[int]] = []
if (
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_list = greedy_search_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -226,14 +291,20 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
@ -241,6 +312,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -253,6 +325,9 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -268,9 +343,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -280,6 +355,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
@ -360,13 +436,21 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -408,6 +492,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -428,6 +517,7 @@ def main():
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(

View File

@ -58,6 +58,7 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.blank_id = blank_id
self.vocab_size = vocab_size
assert context_size >= 1, context_size
self.context_size = context_size

View File

@ -19,30 +19,39 @@ Usage:
(1) greedy search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
@ -56,12 +65,14 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -87,9 +98,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="""Path to bpe.model.""",
)
parser.add_argument(
@ -100,6 +109,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -124,7 +134,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@ -241,15 +277,28 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:

View File

@ -22,15 +22,15 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./transducer_stateless2/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
@ -39,9 +39,20 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless2/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
@ -49,14 +60,16 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -115,6 +128,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -122,8 +136,35 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@ -149,6 +190,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -171,6 +213,9 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -188,24 +233,44 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyp_list: List[List[int]] = []
if (
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_list = greedy_search_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -226,14 +291,20 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
@ -241,6 +312,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -253,6 +325,9 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -268,9 +343,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -280,6 +355,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
@ -360,13 +436,21 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -408,6 +492,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -428,6 +517,7 @@ def main():
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(

View File

@ -19,30 +19,39 @@ Usage:
(1) greedy search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./transducer_stateless2/exp/epoch-xx.pt`.
@ -56,12 +65,14 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -87,9 +98,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="""Path to bpe.model.""",
)
parser.add_argument(
@ -100,6 +109,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -124,7 +134,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@ -241,15 +277,28 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:

View File

@ -22,17 +22,37 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
@ -40,14 +60,16 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -107,6 +129,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -114,8 +137,35 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@ -141,6 +191,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -163,6 +214,9 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -180,24 +234,44 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyp_list = []
batch_size = encoder_out.size(0)
if (
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_list = greedy_search_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -218,14 +292,20 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_list.append(sp.decode(hyp).split())
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
@ -233,6 +313,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -245,6 +326,9 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -260,9 +344,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
log_interval = 50
else:
log_interval = 2
log_interval = 10
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -272,6 +356,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
@ -352,13 +437,21 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -402,6 +495,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -423,6 +521,7 @@ def main():
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(

View File

@ -44,6 +44,15 @@ Usage:
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./transducer_stateless_multi_datasets/exp/epoch-xx.pt`.
Note: ./transducer_stateless_multi_datasets/exp/pretrained.pt is generated by
@ -56,12 +65,14 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -87,9 +98,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="""Path to bpe.model.""",
)
parser.add_argument(
@ -100,6 +109,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -124,7 +134,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@ -241,18 +277,30 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
for i in range(num_waves):
# fmt: off

View File

@ -69,7 +69,7 @@ import torch.nn as nn
from asr_datamodule import TedLiumAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -237,7 +237,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -255,6 +255,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -262,6 +263,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):

View File

@ -72,23 +72,16 @@ import k2
import kaldifeat
import sentencepiece as spm
import torch
import torch.nn as nn
import torchaudio
from beam_search import (
beam_search,
fast_beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
from train import get_params, get_transducer_model
def get_parser():
@ -185,76 +178,16 @@ def get_parser():
""",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
@ -354,7 +287,7 @@ def main():
logging.info(msg)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -372,6 +305,7 @@ def main():
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -379,6 +313,7 @@ def main():
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):

View File

@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
@ -25,6 +26,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@ -37,6 +39,7 @@ LRSchedulerType = object
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@ -51,6 +54,8 @@ def save_checkpoint(
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
model_avg:
The stored model averaged from the start of training.
params:
User defined parameters, e.g., epoch, loss.
optimizer:
@ -80,6 +85,9 @@ def save_checkpoint(
"sampler": sampler.state_dict() if sampler is not None else None,
}
if model_avg is not None:
checkpoint["model_avg"] = model_avg.state_dict()
if params:
for k, v in params.items():
assert k not in checkpoint
@ -91,6 +99,7 @@ def save_checkpoint(
def load_checkpoint(
filename: Path,
model: nn.Module,
model_avg: Optional[nn.Module] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
@ -118,6 +127,11 @@ def load_checkpoint(
checkpoint.pop("model")
if model_avg is not None and "model_avg" in checkpoint:
logging.info("Loading averaged model")
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
checkpoint.pop("model_avg")
def load(name, obj):
s = checkpoint.get(name, None)
if obj and s:
@ -181,6 +195,7 @@ def save_checkpoint_with_global_batch_idx(
out_dir: Path,
global_batch_idx: int,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@ -201,6 +216,8 @@ def save_checkpoint_with_global_batch_idx(
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
model_avg:
The stored model averaged from the start of training.
params:
A dict of training configurations to be saved.
optimizer:
@ -223,6 +240,7 @@ def save_checkpoint_with_global_batch_idx(
save_checkpoint(
filename=filename,
model=model,
model_avg=model_avg,
params=params,
optimizer=optimizer,
scheduler=scheduler,
@ -327,3 +345,129 @@ def remove_checkpoints(
to_remove = checkpoints[topk:]
for c in to_remove:
os.remove(c)
def update_averaged_model(
params: Dict[str, Tensor],
model_cur: Union[nn.Module, DDP],
model_avg: nn.Module,
) -> None:
"""Update the averaged model:
model_avg = model_cur * (average_period / batch_idx_train)
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)
Args:
params:
User defined parameters, e.g., epoch, loss.
model_cur:
The current model.
model_avg:
The averaged model to be updated.
"""
weight_cur = params.average_period / params.batch_idx_train
weight_avg = 1 - weight_cur
if isinstance(model_cur, DDP):
model_cur = model_cur.module
cur = model_cur.state_dict()
avg = model_avg.state_dict()
average_state_dict(
state_dict_1=avg,
state_dict_2=cur,
weight_1=weight_avg,
weight_2=weight_cur,
)
def average_checkpoints_with_averaged_model(
filename_start: str,
filename_end: str,
device: torch.device = torch.device("cpu"),
) -> Dict[str, Tensor]:
"""Average model parameters over the range with given
start model (excluded) and end model.
Let start = batch_idx_train of model-start;
end = batch_idx_train of model-end;
interval = end - start.
Then the average model over range from start (excluded) to end is
(1) avg = (model_end * end - model_start * start) / interval.
It can be written as
(2) avg = model_end * weight_end + model_start * weight_start,
where weight_end = end / interval,
weight_start = -start / interval = 1 - weight_end.
Since the terms `weight_end` and `weight_start` would be large
if the model has been trained for lots of batches, which would cause
overflow when multiplying the model parameters.
To avoid this, we rewrite (2) as:
(3) avg = (model_end + model_start * (weight_start / weight_end))
* weight_end
The model index could be epoch number or iteration number.
Args:
filename_start:
Checkpoint filename of the start model. We assume it
is saved by :func:`save_checkpoint`.
filename_end:
Checkpoint filename of the end model. We assume it
is saved by :func:`save_checkpoint`.
device:
Move checkpoints to this device before averaging.
"""
state_dict_start = torch.load(filename_start, map_location=device)
state_dict_end = torch.load(filename_end, map_location=device)
batch_idx_train_start = state_dict_start["batch_idx_train"]
batch_idx_train_end = state_dict_end["batch_idx_train"]
interval = batch_idx_train_end - batch_idx_train_start
assert interval > 0, interval
weight_end = batch_idx_train_end / interval
weight_start = 1 - weight_end
model_end = state_dict_end["model_avg"]
model_start = state_dict_start["model_avg"]
avg = model_end
# scale the weight to avoid overflow
average_state_dict(
state_dict_1=avg,
state_dict_2=model_start,
weight_1=1.0,
weight_2=weight_start / weight_end,
scaling_factor=weight_end,
)
return avg
def average_state_dict(
state_dict_1: Dict[str, Tensor],
state_dict_2: Dict[str, Tensor],
weight_1: float,
weight_2: float,
scaling_factor: float = 1.0,
) -> Dict[str, Tensor]:
"""Average two state_dict with given weights:
state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
* scaling_factor
It is an in-place operation on state_dict_1 itself.
"""
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in state_dict_1.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for k in uniqued_names:
state_dict_1[k] *= weight_1
state_dict_1[k] += (
state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
)
state_dict_1[k] *= scaling_factor

View File

@ -95,7 +95,7 @@ def get_env_info() -> Dict[str, Any]:
"k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__,
"torch-version": torch.__version__,
"torch-version": str(torch.__version__),
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],