Ignore padding frames during RNN-T decoding. (#358)

* Ignore padding frames during RNN-T decoding.

* Fix outdated decoding code.

* Minor fixes.
Fangjun Kuang 2022-05-13 07:39:14 +08:00 committed by GitHub
parent bc284e88e6
commit aeb8986e35
36 changed files with 2207 additions and 766 deletions
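
The heart of this change: the batched RNN-T decoding paths (greedy_search_batch, modified_beam_search, fast_beam_search_one_best) now receive encoder_out_lens, so frames past an utterance's true length, i.e. padding introduced when batching, no longer emit symbols. Below is a minimal sketch of the underlying padding-mask idea, assuming an (N, T, C) encoder output; make_pad_mask_sketch is illustrative and is not icefall's make_pad_mask:

import torch

def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # True marks padding: position t of row i is padding iff t >= lengths[i].
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

lengths = torch.tensor([3, 5])  # true frame counts in a batch padded to T=5
print(make_pad_mask_sketch(lengths))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])

A decoder that honors this mask stops emitting for utterance i once t >= lengths[i], which is what threading encoder_out_lens through the search functions accomplishes in the diffs below.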

View File

@@ -33,7 +33,7 @@ for sym in 1 2 3; do
     $repo/test_wavs/1221-135766-0002.wav
 done
 
-for method in modified_beam_search beam_search; do
+for method in fast_beam_search modified_beam_search beam_search; do
   log "$method"
 
   ./pruned_transducer_stateless/pretrained.py \
@@ -47,7 +47,8 @@ for method in modified_beam_search beam_search; do
 done
 
 echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
   mkdir -p pruned_transducer_stateless/exp
   ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt
   ln -s $PWD/$repo/data/lang_bpe_500 data/
@@ -58,9 +59,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
   log "Decoding test-clean and test-other"
 
   # use a small value for decoding with CPU
-  max_duration=50
+  max_duration=100
 
-  for method in greedy_search fast_beam_search; do
+  for method in greedy_search fast_beam_search modified_beam_search; do
     log "Decoding with $method"
 
     ./pruned_transducer_stateless/decode.py \

View File

@@ -51,7 +51,8 @@ for method in modified_beam_search beam_search fast_beam_search; do
 done
 
 echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
   mkdir -p pruned_transducer_stateless2/exp
   ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt
   ln -s $PWD/$repo/data/lang_bpe_500 data/
@@ -62,9 +63,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
   log "Decoding test-clean and test-other"
 
   # use a small value for decoding with CPU
-  max_duration=50
+  max_duration=100
 
-  for method in greedy_search fast_beam_search; do
+  for method in greedy_search fast_beam_search modified_beam_search; do
     log "Decoding with $method"
 
     ./pruned_transducer_stateless2/decode.py \

View File

@@ -51,7 +51,8 @@ for method in modified_beam_search beam_search fast_beam_search; do
 done
 
 echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
   mkdir -p pruned_transducer_stateless3/exp
   ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
   ln -s $PWD/$repo/data/lang_bpe_500 data/
@@ -62,9 +63,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
   log "Decoding test-clean and test-other"
 
   # use a small value for decoding with CPU
-  max_duration=50
+  max_duration=100
 
-  for method in greedy_search fast_beam_search; do
+  for method in greedy_search fast_beam_search modified_beam_search; do
     log "Decoding with $method"
 
     ./pruned_transducer_stateless3/decode.py \

View File

@@ -33,7 +33,7 @@ for sym in 1 2 3; do
     $repo/test_wavs/1221-135766-0002.wav
 done
 
-for method in modified_beam_search beam_search; do
+for method in fast_beam_search modified_beam_search beam_search; do
   log "$method"
 
   ./transducer_stateless2/pretrained.py \
@@ -47,7 +47,8 @@ for method in modified_beam_search beam_search; do
 done
 
 echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
   mkdir -p transducer_stateless2/exp
   ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt
   ln -s $PWD/$repo/data/lang_bpe_500 data/
@@ -58,9 +59,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
   log "Decoding test-clean and test-other"
 
   # use a small value for decoding with CPU
-  max_duration=50
+  max_duration=100
 
-  for method in greedy_search modified_beam_search; do
+  for method in greedy_search fast_beam_search modified_beam_search; do
     log "Decoding with $method"
 
     ./transducer_stateless2/decode.py \

View File

@@ -33,7 +33,7 @@ for sym in 1 2 3; do
     $repo/test_wavs/1221-135766-0002.wav
 done
 
-for method in modified_beam_search beam_search; do
+for method in modified_beam_search beam_search fast_beam_search; do
   log "$method"
 
   ./transducer_stateless_multi_datasets/pretrained.py \
@@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
 done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+  mkdir -p transducer_stateless_multi_datasets/exp
+  ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh transducer_stateless_multi_datasets/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in greedy_search fast_beam_search modified_beam_search; do
+    log "Decoding with $method"
+
+    ./transducer_stateless_multi_datasets/decode.py \
+      --decoding-method $method \
+      --epoch 999 \
+      --avg 1 \
+      --max-duration $max_duration \
+      --exp-dir transducer_stateless_multi_datasets/exp
+  done
+
+  rm transducer_stateless_multi_datasets/exp/*.pt
+fi

View File

@@ -33,7 +33,7 @@ for sym in 1 2 3; do
     $repo/test_wavs/1221-135766-0002.wav
 done
 
-for method in modified_beam_search beam_search; do
+for method in modified_beam_search beam_search fast_beam_search; do
  log "$method"
 
  ./transducer_stateless_multi_datasets/pretrained.py \
@@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
 done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+  mkdir -p transducer_stateless_multi_datasets/exp
+  ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh transducer_stateless_multi_datasets/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in greedy_search fast_beam_search modified_beam_search; do
+    log "Decoding with $method"
+
+    ./transducer_stateless_multi_datasets/decode.py \
+      --decoding-method $method \
+      --epoch 999 \
+      --avg 1 \
+      --max-duration $max_duration \
+      --exp-dir transducer_stateless_multi_datasets/exp
+  done
+
+  rm transducer_stateless_multi_datasets/exp/*.pt
+fi

View File

@@ -33,7 +33,7 @@ for sym in 1 2 3; do
     $repo/test_wavs/1221-135766-0002.wav
 done
 
-for method in modified_beam_search beam_search; do
+for method in fast_beam_search modified_beam_search beam_search; do
   log "$method"
 
   ./transducer_stateless/pretrained.py \
@@ -46,15 +46,31 @@ for method in modified_beam_search beam_search; do
     $repo/test_wavs/1221-135766-0002.wav
 done
-
-for method in modified_beam_search beam_search; do
-  log "$method"
-
-  ./transducer_stateless_multi_datasets/pretrained.py \
-    --method $method \
-    --beam-size 4 \
-    --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
-    $repo/test_wavs/1089-134686-0001.wav \
-    $repo/test_wavs/1221-135766-0001.wav \
-    $repo/test_wavs/1221-135766-0002.wav
-done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+  mkdir -p transducer_stateless/exp
+  ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh transducer_stateless/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in greedy_search fast_beam_search modified_beam_search; do
+    log "Decoding with $method"
+
+    ./transducer_stateless/decode.py \
+      --decoding-method $method \
+      --epoch 999 \
+      --avg 1 \
+      --max-duration $max_duration \
+      --exp-dir transducer_stateless/exp
+  done
+
+  rm transducer_stateless/exp/*.pt
+fi

View File

@@ -35,7 +35,7 @@ on:
 jobs:
   run_librispeech_2022_03_12:
-    if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -107,11 +107,11 @@ jobs:
         run: |
          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
 
      - name: Inference with pre-trained model
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          mkdir -p egs/librispeech/ASR/data
          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
@@ -124,8 +124,8 @@ jobs:
          .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
 
-      - name: Display decoding results
-        if: github.event_name == 'schedule'
+      - name: Display decoding results for pruned_transducer_stateless
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
        shell: bash
        run: |
          cd egs/librispeech/ASR/
@@ -141,9 +141,13 @@ jobs:
          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
      - name: Upload decoding results for pruned_transducer_stateless
        uses: actions/upload-artifact@v2
-        if: github.event_name == 'schedule'
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
        with:
          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
          path: egs/librispeech/ASR/pruned_transducer_stateless/exp/

View File

@@ -35,7 +35,7 @@ on:
 jobs:
   run_librispeech_2022_04_29:
-    if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -111,6 +111,7 @@ jobs:
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          mkdir -p egs/librispeech/ASR/data
          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
@@ -125,44 +126,54 @@ jobs:
          .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
 
-      - name: Display decoding results
-        if: github.event_name == 'schedule'
+      - name: Display decoding results for pruned_transducer_stateless2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
        shell: bash
        run: |
          cd egs/librispeech/ASR
          tree pruned_transducer_stateless2/exp
-          cd pruned_transducer_stateless2
-          echo "results for pruned_transducer_stateless2"
+          cd pruned_transducer_stateless2/exp
          echo "===greedy search==="
-          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
-          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
          echo "===fast_beam_search==="
-          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
-          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-          cd ../
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Display decoding results for pruned_transducer_stateless3
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR
          tree pruned_transducer_stateless3/exp
-          cd pruned_transducer_stateless3
-          echo "results for pruned_transducer_stateless3"
+          cd pruned_transducer_stateless3/exp
          echo "===greedy search==="
-          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
-          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
          echo "===fast_beam_search==="
-          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
-          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 
      - name: Upload decoding results for pruned_transducer_stateless2
        uses: actions/upload-artifact@v2
-        if: github.event_name == 'schedule'
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
        with:
          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
          path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
 
      - name: Upload decoding results for pruned_transducer_stateless3
        uses: actions/upload-artifact@v2
-        if: github.event_name == 'schedule'
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
        with:
          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
          path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/

View File

@@ -35,7 +35,7 @@ on:
 jobs:
   run_librispeech_2022_04_19:
-    if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -111,6 +111,7 @@ jobs:
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          mkdir -p egs/librispeech/ASR/data
          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
@@ -124,7 +125,7 @@ jobs:
          .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
 
      - name: Display decoding results
-        if: github.event_name == 'schedule'
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
        shell: bash
        run: |
          cd egs/librispeech/ASR/
@@ -136,13 +137,17 @@ jobs:
          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 
+          echo "===fast_beam_search==="
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
          echo "===modified_beam_search==="
          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 
      - name: Upload decoding results for transducer_stateless2
        uses: actions/upload-artifact@v2
-        if: github.event_name == 'schedule'
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
        with:
          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
          path: egs/librispeech/ASR/transducer_stateless2/exp/

View File

@@ -23,9 +23,18 @@ on:
   pull_request:
     types: [labeled]
 
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
 jobs:
   run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -64,11 +73,80 @@ jobs:
        run: |
          .github/scripts/install-kaldifeat.sh
 
-      - name: Inference with pre-trained model
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
        shell: bash
        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
          sudo apt-get -qq install git-lfs tree sox
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
+
+      - name: Display decoding results for transducer_stateless_multi_datasets
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./transducer_stateless_multi_datasets/exp
+          cd transducer_stateless_multi_datasets
+          echo "results for transducer_stateless_multi_datasets"
+          echo "===greedy search==="
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          echo "===fast_beam_search==="
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for transducer_stateless_multi_datasets
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
+          path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/

View File

@@ -23,9 +23,18 @@ on:
   pull_request:
     types: [labeled]
 
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
 jobs:
   run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -64,11 +73,80 @@ jobs:
        run: |
          .github/scripts/install-kaldifeat.sh
 
-      - name: Inference with pre-trained model
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
        shell: bash
        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
          sudo apt-get -qq install git-lfs tree sox
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
+
+      - name: Display decoding results for transducer_stateless_multi_datasets
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./transducer_stateless_multi_datasets/exp
+          cd transducer_stateless_multi_datasets
+          echo "results for transducer_stateless_multi_datasets"
+          echo "===greedy search==="
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          echo "===fast_beam_search==="
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for transducer_stateless_multi_datasets
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
+          path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/

View File

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-name: run-pre-trained-trandsucer-stateless
+name: run-pre-trained-transducer-stateless
 
 on:
   push:
@@ -23,9 +23,18 @@ on:
   pull_request:
     types: [labeled]
 
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
 jobs:
   run_pre_trained_transducer_stateless:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
      matrix:
@@ -64,11 +73,80 @@ jobs:
        run: |
          .github/scripts/install-kaldifeat.sh
 
-      - name: Inference with pre-trained model
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
        shell: bash
        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
          sudo apt-get -qq install git-lfs tree sox
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
          .github/scripts/run-pre-trained-transducer-stateless.sh
+
+      - name: Display decoding results for transducer_stateless
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./transducer_stateless/exp
+          cd transducer_stateless
+          echo "results for transducer_stateless"
+          echo "===greedy search==="
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          echo "===fast_beam_search==="
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for transducer_stateless
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
+          path: egs/librispeech/ASR/transducer_stateless/exp/

View File

@@ -110,7 +110,9 @@ class Conformer(Transformer):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            lengths = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
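
The guard above only silences a warning; the arithmetic is unchanged. The formula models two stride-2 subsampling stages, each mapping T to (T - 1) // 2, and the warnings context presumably suppresses the floor-division deprecation warning that some PyTorch releases emit for // on tensors. A quick sanity check of the formula:

import torch

x_lens = torch.tensor([100, 83, 64])    # frame counts before subsampling
lengths = ((x_lens - 1) // 2 - 1) // 2  # two stride-2 stages
print(lengths)                          # tensor([24, 20, 15])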

View File

@@ -19,49 +19,62 @@
 Usage:
 (1) greedy search
 ./transducer_stateless_modified-2/decode.py \
     --epoch 89 \
     --avg 38 \
     --exp-dir ./transducer_stateless_modified-2/exp \
     --max-duration 100 \
     --decoding-method greedy_search
 
-(2) beam search
-./transducer_stateless_modified/decode.py \
+(2) beam search (not recommended)
+./transducer_stateless_modified-2/decode.py \
     --epoch 89 \
     --avg 38 \
     --exp-dir ./transducer_stateless_modified-2/exp \
     --max-duration 100 \
     --decoding-method beam_search \
     --beam-size 4
 
 (3) modified beam search
 ./transducer_stateless_modified-2/decode.py \
     --epoch 89 \
     --avg 38 \
-    --exp-dir ./transducer_stateless_modified/exp \
+    --exp-dir ./transducer_stateless_modified-2/exp \
     --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4
 
+(4) fast beam search
+./transducer_stateless_modified-2/decode.py \
+    --epoch 89 \
+    --avg 38 \
+    --exp-dir ./transducer_stateless_modified-2/exp \
+    --max-duration 100 \
+    --decoding-method fast_beam_search \
+    --beam-size 4 \
+    --max-contexts 4 \
+    --max-states 8
 """
 
 import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import k2
 import torch
 import torch.nn as nn
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -114,6 +127,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
 
@@ -121,8 +135,35 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --decoding-method is beam_search "
-        "and modified_beam_search",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
     )
 
     parser.add_argument(
@@ -132,84 +173,24 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
 
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
-        help="Maximum number of symbols per frame",
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
     )
 
     return parser
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    lexicon: Lexicon,
+    token_table: k2.SymbolTable,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -230,8 +211,8 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
+      token_table:
+        It maps token ID to a string.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -249,44 +230,80 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
 
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append([lexicon.token_table[i] for i in hyp])
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+    else:
+        hyp_tokens = []
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_tokens.append(hyp)
+    hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     else:
-        return {f"beam_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": hyps}
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    lexicon: Lexicon,
+    token_table: k2.SymbolTable,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -297,6 +314,11 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
+      token_table:
+        It maps a token ID to a string.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -312,9 +334,9 @@ def decode_dataset(
         num_batches = "?"
 
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -323,7 +345,8 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            lexicon=lexicon,
+            token_table=token_table,
+            decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -358,6 +381,7 @@ def save_results(
         params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
     )
     store_transcripts(filename=recog_path, texts=results)
+    logging.info(f"The transcripts are stored in {recog_path}")
 
     # The following prints out WERs, per-word error statistics and aligned
     # ref/hyp pairs.
@@ -408,13 +432,21 @@ def main():
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
+        "fast_beam_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -456,6 +488,11 @@ def main():
     model.eval()
     model.device = device
 
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -472,7 +509,8 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
-            lexicon=lexicon,
+            token_table=lexicon.token_table,
+            decoding_graph=decoding_graph,
         )
 
         save_results(
@@ -484,8 +522,5 @@ def main():
     logging.info("Done!")
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
 
 if __name__ == "__main__":
     main()
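
Note that the --beam value added above is a score cutoff rather than a hypothesis count: per its help text, states scoring below cutoff = max-score - beam are pruned, matching Kaldi's notion of a beam. A tiny numeric illustration with made-up scores:

max_score = -3.2            # best log-score among active states
beam = 4.0                  # the --beam value
cutoff = max_score - beam   # -7.2
scores = [-3.2, -5.0, -7.5, -8.1]
kept = [s for s in scores if s >= cutoff]
print(kept)                 # [-3.2, -5.0]; the last two fall below the cutoff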

View File

@@ -19,7 +19,7 @@
 """
 Usage:
 
-# greedy search
+(1) greedy search
 ./transducer_stateless_modified-2/pretrained.py \
     --checkpoint /path/to/pretrained.pt \
     --lang-dir /path/to/lang_char \
@@ -27,7 +27,7 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav
 
-# beam search
+(2) beam search
 ./transducer_stateless_modified-2/pretrained.py \
     --checkpoint /path/to/pretrained.pt \
     --lang-dir /path/to/lang_char \
@@ -36,7 +36,7 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav
 
-# modified beam search
+(3) modified beam search
 ./transducer_stateless_modified-2/pretrained.py \
     --checkpoint /path/to/pretrained.pt \
     --lang-dir /path/to/lang_char \
@@ -45,6 +45,14 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav
 
+(4) fast beam search
+./transducer_stateless_modified-2/pretrained.py \
+    --checkpoint /path/to/pretrained.pt \
+    --lang-dir /path/to/lang_char \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
 """
 
 import argparse
@@ -53,11 +61,13 @@ import math
 from pathlib import Path
 from typing import List
 
+import k2
 import kaldifeat
 import torch
 import torchaudio
 from beam_search import (
     beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -97,6 +107,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
 
@@ -121,7 +132,33 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --method is beam_search and modified_beam_search",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
    )
 
     parser.add_argument(
@@ -134,11 +171,10 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="Maximum number of symbols per frame. "
         "Use only when --method is greedy_search",
     )
-    return parser
 
     return parser
@@ -225,20 +261,37 @@ def main():
     encoder_out, encoder_out_lens = model.encoder(
         x=features, x_lens=feature_lens
     )
+    num_waves = encoder_out.size(0)
     hyp_list = []
-    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+    logging.info(f"Using {params.method}")
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_list = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_list = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
     elif params.method == "modified_beam_search":
         hyp_list = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
     else:
-        for i in range(encoder_out.size(0)):
+        for i in range(num_waves):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
             # fmt: on
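
The fallback branch above still decodes one utterance at a time, slicing each to its true length so padding frames never reach the search; the batched paths get the same effect from encoder_out_lens. A self-contained illustration of that slicing, with made-up shapes:

import torch

encoder_out = torch.randn(2, 5, 4)       # (N, T, C), batch padded to T=5
encoder_out_lens = torch.tensor([3, 5])  # valid frames per utterance

for i in range(encoder_out.size(0)):
    encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
    print(encoder_out_i.shape)
# torch.Size([1, 3, 4])
# torch.Size([1, 5, 4])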

View File

@@ -19,48 +19,63 @@
 Usage:
 (1) greedy search
 ./transducer_stateless_modified/decode.py \
-    --epoch 64 \
-    --avg 33 \
+    --epoch 14 \
+    --avg 7 \
     --exp-dir ./transducer_stateless_modified/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method greedy_search
 
-(2) beam search
+(2) beam search (not recommended)
 ./transducer_stateless_modified/decode.py \
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_stateless_modified/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
 
 (3) modified beam search
 ./transducer_stateless_modified/decode.py \
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_stateless_modified/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
 
+(4) fast beam search
+./transducer_stateless_modified/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir ./transducer_stateless_modified/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
 """
 
 import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -113,6 +128,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
@@ -120,7 +136,35 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --decoding-method is beam_search",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
     )
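The `cutoff = max-score - beam` formula in the --beam help is the key difference from --beam-size: --beam-size caps the number of kept hypotheses, while --beam is a score margin, so the number of survivors adapts to how peaked the scores are. A toy illustration of the rule itself (the actual pruning happens inside k2's RNN-T decoding streams, not in Python):

beam = 4.0
scores = [-1.2, -3.0, -9.7, -2.4]
cutoff = max(scores) - beam                    # -5.2
survivors = [s for s in scores if s >= cutoff]
print(survivors)                               # [-1.2, -3.0, -2.4]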
     parser.add_argument(
@@ -130,84 +174,24 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
-        help="Maximum number of symbols per frame",
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
     )
     return parser
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    lexicon: Lexicon,
+    token_table: k2.SymbolTable,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -228,8 +212,11 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
+      token_table:
+        It maps token ID to a string.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -247,44 +234,80 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append([lexicon.token_table[i] for i in hyp])
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+    else:
+        hyp_tokens = []
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_tokens.append(hyp)
+    hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     else:
-        return {f"beam_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": hyps}
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    lexicon: Lexicon,
+    token_table: k2.SymbolTable,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -295,6 +318,11 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
+      token_table:
+        It maps a token ID to a string.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -310,9 +338,9 @@ def decode_dataset(
         num_batches = "?"
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -321,7 +349,8 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            lexicon=lexicon,
+            token_table=token_table,
+            decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -356,6 +385,7 @@ def save_results(
         params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
     )
     store_transcripts(filename=recog_path, texts=results)
+    logging.info(f"The transcripts are stored in {recog_path}")
     # The following prints out WERs, per-word error statistics and aligned
     # ref/hyp pairs.
@@ -406,13 +436,21 @@ def main():
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
+        "fast_beam_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -452,6 +490,11 @@ def main():
     model.eval()
     model.device = device
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
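For fast_beam_search, the decoding graph built above is deliberately trivial: roughly a one-state acceptor with a self-loop for every non-blank token, so it accepts any token sequence and contributes no graph scores of its own. A minimal sketch of constructing one in isolation (the vocab_size here is hypothetical; decode.py takes it from the lexicon):

import k2
import torch

vocab_size = 500  # hypothetical
device = torch.device("cpu")
# self-loops cover tokens 1..vocab_size-1; blank (0) is handled by the decoder
decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)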
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -467,7 +510,8 @@ def main():
             dl=test_dl,
             params=params,
             model=model,
-            lexicon=lexicon,
+            token_table=lexicon.token_table,
+            decoding_graph=decoding_graph,
         )
         save_results(
@@ -479,8 +523,5 @@ def main():
     logging.info("Done!")
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
 if __name__ == "__main__":
     main()

View File

@@ -19,7 +19,7 @@
 """
 Usage:
-# greedy search
+(1) greedy search
 ./transducer_stateless_modified/pretrained.py \
     --checkpoint /path/to/pretrained.pt \
     --lang-dir /path/to/lang_char \
@@ -27,7 +27,7 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav
-# beam search
+(2) beam search
 ./transducer_stateless_modified/pretrained.py \
     --checkpoint /path/to/pretrained.pt \
     --lang-dir /path/to/lang_char \
@@ -36,7 +36,7 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav
-# modified beam search
+(3) modified beam search
 ./transducer_stateless_modified/pretrained.py \
     --checkpoint /path/to/pretrained.pt \
     --lang-dir /path/to/lang_char \
@@ -45,6 +45,14 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav
+(4) fast beam search
+./transducer_stateless_modified/pretrained.py \
+    --checkpoint /path/to/pretrained.pt \
+    --lang-dir /path/to/lang_char \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
 """
 import argparse
@@ -53,11 +61,13 @@ import math
 from pathlib import Path
 from typing import List
+import k2
 import kaldifeat
 import torch
 import torchaudio
 from beam_search import (
     beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -97,6 +107,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
@@ -121,7 +132,33 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --method is beam_search and modified_beam_search",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
     )
     parser.add_argument(
@@ -134,11 +171,10 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="Maximum number of symbols per frame. "
         "Use only when --method is greedy_search",
     )
-    return parser
     return parser
@@ -225,20 +261,37 @@ def main():
     encoder_out, encoder_out_lens = model.encoder(
         x=features, x_lens=feature_lens
     )
+    num_waves = encoder_out.size(0)
     hyp_list = []
-    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+    logging.info(f"Using {params.method}")
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_list = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_list = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
     elif params.method == "modified_beam_search":
         hyp_list = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
     else:
-        for i in range(encoder_out.size(0)):
+        for i in range(num_waves):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
             # fmt: on

View File

@@ -27,6 +27,149 @@ from icefall.decode import Nbest, one_best_decoding
 from icefall.utils import get_texts
+def fast_beam_search_one_best(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+) -> List[List[int]]:
+    """It limits the maximum number of symbols per frame to 1.
+    A lattice is first obtained using fast beam search, and then
+    the shortest path within the lattice is used as the final output.
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding, may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+    Returns:
+      Return the decoded result.
+    """
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+    )
+    best_path = one_best_decoding(lattice)
+    hyps = get_texts(best_path)
+    return hyps
+def fast_beam_search_nbest_oracle(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+    num_paths: int,
+    ref_texts: List[List[int]],
+    use_double_scores: bool = True,
+    nbest_scale: float = 0.5,
+) -> List[List[int]]:
+    """It limits the maximum number of symbols per frame to 1.
+    A lattice is first obtained using fast beam search, and then
+    we select `num_paths` linear paths from the lattice. The path
+    that has the minimum edit distance with the given reference transcript
+    is used as the output.
+    This is the best result we can achieve for any nbest based rescoring
+    methods.
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding, may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+      num_paths:
+        Number of paths to extract from the decoded lattice.
+      ref_texts:
+        A list-of-list of integers containing the reference transcripts.
+        If the decoding_graph is a trivial_graph, the integer ID is the
+        BPE token ID.
+      use_double_scores:
+        True to use double precision for computation. False to use
+        single precision.
+      nbest_scale:
+        It's the scale applied to the lattice.scores. A smaller value
+        yields more unique paths.
+    Returns:
+      Return the decoded result.
+    """
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+    )
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+    hyps = nbest.build_levenshtein_graphs()
+    refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
+    levenshtein_alignment = k2.levenshtein_alignment(
+        refs=refs,
+        hyps=hyps,
+        hyp_to_ref_map=nbest.shape.row_ids(1),
+        sorted_match_ref=True,
+    )
+    tot_scores = levenshtein_alignment.get_tot_scores(
+        use_double_scores=False, log_semiring=False
+    )
+    ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+    max_indexes = ragged_tot_scores.argmax()
+    best_path = k2.index_fsa(nbest.fsa, max_indexes)
+    hyps = get_texts(best_path)
+    return hyps
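The oracle works by aligning every extracted path against its reference with k2's Levenshtein machinery: the alignment's total score decreases as the edit distance grows, so the argmax over each utterance's candidate paths selects the closest one. A standalone toy run of that mechanism under the same assumption (scores come from k2's default levenshtein_graph penalties):

import k2
import torch

refs = k2.levenshtein_graph([[1, 2, 3]])
hyps = k2.levenshtein_graph([[1, 3], [1, 2, 3]])  # one deletion vs. exact match
alignment = k2.levenshtein_alignment(
    refs=refs,
    hyps=hyps,
    hyp_to_ref_map=torch.tensor([0, 0], dtype=torch.int32),
    sorted_match_ref=True,
)
tot_scores = alignment.get_tot_scores(use_double_scores=False, log_semiring=False)
print(tot_scores.argmax().item())  # 1: the exact match has the best score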
 def fast_beam_search(
     model: Transducer,
     decoding_graph: k2.Fsa,
@@ -35,8 +178,7 @@ def fast_beam_search(
     beam: float,
     max_states: int,
     max_contexts: int,
-    use_max: bool = False,
-) -> List[List[int]]:
+) -> k2.Fsa:
     """It limits the maximum number of symbols per frame to 1.
     Args:
@@ -55,11 +197,10 @@ def fast_beam_search(
       max_states:
         Max states per stream per frame.
       max_contexts:
        Max contexts per stream per frame.
-      use_max:
-        True to use max operation to select the hypothesis with the largest
-        log_prob when there are duplicate hypotheses; False to use log-add.
     Returns:
-      Return the decoded result.
+      Return an FsaVec with axes [utt][state][arc] containing the decoded
+      lattice. Note: When the input graph is a TrivialGraph, the returned
+      lattice is actually an acceptor.
     """
     assert encoder_out.ndim == 3
@@ -92,7 +233,7 @@ def fast_beam_search(
     # (shape.NumElements(), 1, encoder_out_dim)
     # fmt: off
     current_encoder_out = torch.index_select(
-        encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).long()
+        encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
+        # in some old versions of pytorch, the type of index is required
+        # to be LongTensor. In the newest version of pytorch, the type
+        # of index can be IntTensor or LongTensor. For supporting the
@@ -109,67 +250,7 @@ def fast_beam_search(
     decoding_streams.terminate_and_flush_to_streams()
     lattice = decoding_streams.format_output(encoder_out_lens.tolist())
-    if use_max:
+    return lattice
-        best_path = one_best_decoding(lattice)
-        hyps = get_texts(best_path)
-        return hyps
-    else:
-        num_paths = 200
-        use_double_scores = True
-        nbest_scale = 0.8
-        nbest = Nbest.from_lattice(
-            lattice=lattice,
-            num_paths=num_paths,
-            use_double_scores=use_double_scores,
-            nbest_scale=nbest_scale,
-        )
-        # The following code is modified from nbest.intersect()
-        word_fsa = k2.invert(nbest.fsa)
-        if hasattr(lattice, "aux_labels"):
-            # delete token IDs as it is not needed
-            del word_fsa.aux_labels
-        word_fsa.scores.zero_()
-        word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
-        path_to_utt_map = nbest.shape.row_ids(1)
-        if hasattr(lattice, "aux_labels"):
-            # lattice has token IDs as labels and word IDs as aux_labels.
-            # inv_lattice has word IDs as labels and token IDs as aux_labels
-            inv_lattice = k2.invert(lattice)
-            inv_lattice = k2.arc_sort(inv_lattice)
-        else:
-            inv_lattice = k2.arc_sort(lattice)
-        if inv_lattice.shape[0] == 1:
-            path_lattice = k2.intersect_device(
-                inv_lattice,
-                word_fsa_with_epsilon_loops,
-                b_to_a_map=torch.zeros_like(path_to_utt_map),
-                sorted_match_a=True,
-            )
-        else:
-            path_lattice = k2.intersect_device(
-                inv_lattice,
-                word_fsa_with_epsilon_loops,
-                b_to_a_map=path_to_utt_map,
-                sorted_match_a=True,
-            )
-        # path_lattice has word IDs as labels and token IDs as aux_labels
-        path_lattice = k2.top_sort(k2.connect(path_lattice))
-        tot_scores = path_lattice.get_tot_scores(
-            use_double_scores=use_double_scores, log_semiring=True
-        )
-        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
-        best_hyp_indexes = ragged_tot_scores.argmax()
-        best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
-        hyps = get_texts(best_path)
-        return hyps
 def greedy_search(
@@ -193,10 +274,10 @@ def greedy_search(
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    unk_id = model.decoder.unk_id
     context_size = model.decoder.context_size
+    unk_id = getattr(model, "unk_id", blank_id)
-    device = model.device
+    device = next(model.parameters()).device
     decoder_input = torch.tensor(
         [blank_id] * context_size, device=device, dtype=torch.int64
@@ -230,7 +311,7 @@ def greedy_search(
     # logits is (1, 1, 1, vocab_size)
     y = logits.argmax().item()
-    if y != blank_id and y != unk_id:
+    if y not in (blank_id, unk_id):
         hyp.append(y)
         decoder_input = torch.tensor(
             [hyp[-context_size:]], device=device
@@ -249,7 +330,9 @@ def greedy_search(
 def greedy_search_batch(
-    model: Transducer, encoder_out: torch.Tensor
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -257,6 +340,9 @@ def greedy_search_batch(
       model:
        The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C), where N >= 1.
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
     Returns:
       Return a list-of-list of token IDs containing the decoded results.
       len(ans) equals to encoder_out.size(0).
@@ -264,28 +350,48 @@ def greedy_search_batch(
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    device = model.device
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    device = next(model.parameters()).device
     blank_id = model.decoder.blank_id
-    unk_id = model.decoder.unk_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+    hyps = [[blank_id] * context_size for _ in range(N)]
     decoder_input = torch.tensor(
         hyps,
         device=device,
         dtype=torch.int64,
-    )  # (batch_size, context_size)
+    )  # (N, context_size)
     decoder_out = model.decoder(decoder_input, need_pad=False)
-    # decoder_out: (batch_size, 1, decoder_out_dim)
+    # decoder_out: (N, 1, decoder_out_dim)
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    encoder_out = packed_encoder_out.data
+    offset = 0
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+        decoder_out = decoder_out[:batch_size]
         logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
         # logits' shape: (batch_size, 1, 1, vocab_size)
@@ -294,12 +400,12 @@ def greedy_search_batch(
         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
-            if v != blank_id and v != unk_id:
+            if v not in (blank_id, unk_id):
                 hyps[i].append(v)
                 emitted = True
         if emitted:
             # update decoder output
-            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
             decoder_input = torch.tensor(
                 decoder_input,
                 device=device,
@@ -307,7 +413,12 @@ def greedy_search_batch(
             )
             decoder_out = model.decoder(decoder_input, need_pad=False)
-    ans = [h[context_size:] for h in hyps]
+    sorted_ans = [h[context_size:] for h in hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
     return ans
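This packed-sequence trick is the heart of the padding fix: instead of stepping through all T padded frames for every utterance, the loop walks packed_encoder_out.batch_sizes, which says how many streams still have a valid frame at each time step, and simply shrinks the batch as utterances run out. A standalone toy run with hypothetical lengths:

import torch

encoder_out = torch.zeros(3, 4, 1)         # (N=3, T=4, C=1), toy values
encoder_out_lens = torch.tensor([2, 4, 3])

packed = torch.nn.utils.rnn.pack_padded_sequence(
    input=encoder_out,
    lengths=encoder_out_lens.cpu(),
    batch_first=True,
    enforce_sorted=False,
)
print(packed.batch_sizes.tolist())       # [3, 3, 2, 1]: active streams per frame
print(packed.unsorted_indices.tolist())  # [2, 0, 1]: undoes the length sort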
@@ -472,6 +583,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
     beam: int = 4,
     use_max: bool = False,
 ) -> List[List[int]]:
@@ -482,6 +594,9 @@ def modified_beam_search(
       model:
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
       beam:
         Number of active paths during the beam search.
       use_max:
@@ -492,16 +607,27 @@ def modified_beam_search(
       for the i-th utterance.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
     blank_id = model.decoder.blank_id
-    unk_id = model.decoder.unk_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
-    B = [HypothesisList() for _ in range(batch_size)]
-    for i in range(batch_size):
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
         B[i].add(
             Hypothesis(
                 ys=[blank_id] * context_size,
@@ -510,9 +636,20 @@ def modified_beam_search(
                 use_max=use_max,
             )
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    encoder_out = packed_encoder_out.data
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
         hyps_shape = _get_hyps_shape(B).to(device)
@@ -577,15 +714,21 @@ def modified_beam_search(
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
-                if new_token != blank_id and new_token != unk_id:
+                if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                 new_log_prob = topk_log_probs[k]
                 new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                 B[i].add(new_hyp)
+    B = B + finalized_B
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
-    ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
     return ans
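The finalized_B bookkeeping deserves a remark: pack_padded_sequence orders streams longest-first, so when the active batch shrinks from one frame to the next, the trailing entries of B are exactly the streams that just ran out of frames; prepending them to finalized_B keeps `B + finalized_B` in sorted order for the final unsorting step. A toy trace with stand-in hypothesis lists (the names are hypothetical):

batch_size_list = [3, 3, 2, 1]      # as reported by packed.batch_sizes
B = ["utt_a", "utt_b", "utt_c"]     # sorted longest-first, like the real code
finalized_B = []
for batch_size in batch_size_list:
    finalized_B = B[batch_size:] + finalized_B  # streams with no frames left
    B = B[:batch_size]                          # streams still being decoded
B = B + finalized_B
print(B)  # ['utt_a', 'utt_b', 'utt_c']: sorted order restored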
@@ -622,10 +765,10 @@ def _deprecated_modified_beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    unk_id = model.decoder.unk_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     T = encoder_out.size(1)
@@ -691,7 +834,7 @@ def _deprecated_modified_beam_search(
         hyp = A[topk_hyp_indexes[i]]
         new_ys = hyp.ys[:]
         new_token = topk_token_indexes[i]
-        if new_token != blank_id and new_token != unk_id:
+        if new_token not in (blank_id, unk_id):
             new_ys.append(new_token)
             new_log_prob = topk_log_probs[i]
             new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
@@ -732,10 +875,10 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    unk_id = model.decoder.unk_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     decoder_input = torch.tensor(
         [blank_id] * context_size,
@@ -818,7 +961,7 @@ def beam_search(
     # Second, process other non-blank labels
     values, indices = log_prob.topk(beam + 1)
     for i, v in zip(indices.tolist(), values.tolist()):
-        if i == blank_id or i == unk_id:
+        if i in (blank_id, unk_id):
             continue
         new_ys = y_star.ys + [i]
         new_log_prob = y_star.log_prob + v

View File

@@ -19,53 +19,53 @@
 Usage:
 (1) greedy search
 ./pruned_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method greedy_search
-(2) beam search
+(2) beam search (not recommended)
 ./pruned_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
 (3) modified beam search
 ./pruned_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
 (4) fast beam search
 ./pruned_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless/exp \
-    --max-duration 1500 \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
 (5) fast beam search using LG
 ./pruned_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless/exp \
     --use-LG True \
     --use-max False \
-    --max-duration 1500 \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 8 \
     --max-contexts 8 \
     --max-states 64
 """
@@ -82,7 +82,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -307,7 +307,7 @@ def decode_one_batch(
     hyps = []
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -315,7 +315,6 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
-            use_max=params.use_max,
         )
         if params.use_LG:
             for hyp in hyp_tokens:
@@ -330,6 +329,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -337,6 +337,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             use_max=params.use_max,
         )
@@ -421,9 +422,9 @@ def decode_dataset(
         num_batches = "?"
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):

View File

@@ -25,7 +25,7 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav \
-(1) beam search
+(2) beam search
 ./pruned_transducer_stateless/pretrained.py \
     --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
     --bpe-model ./data/lang_bpe_500/bpe.model \
@@ -34,6 +34,24 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav \
+(3) modified beam search
+./pruned_transducer_stateless/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav \
+(4) fast beam search
+./pruned_transducer_stateless/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav \
 You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
 Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by
@@ -46,12 +64,14 @@ import logging
 import math
 from typing import List
+import k2
 import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (
     beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -77,9 +97,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        help="""Path to bpe.model.
-        Used only when method is ctc-decoding.
-        """,
+        help="""Path to bpe.model.""",
     )
     parser.add_argument(
@@ -90,6 +108,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
@@ -114,7 +133,33 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --method is beam_search and modified_beam_search",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
     )
     parser.add_argument(
@@ -230,10 +275,25 @@ def main():
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
-    if params.method == "modified_beam_search":
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
@@ -243,6 +303,7 @@ def main():
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())

View File

@@ -335,7 +335,9 @@ def greedy_search(
 def greedy_search_batch(
-    model: Transducer, encoder_out: torch.Tensor
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -343,6 +345,9 @@ def greedy_search_batch(
       model:
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C), where N >= 1.
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
     Returns:
       Return a list-of-list of token IDs containing the decoded results.
       len(ans) equals to encoder_out.size(0).
@@ -350,31 +355,49 @@ def greedy_search_batch(
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    device = next(model.parameters()).device
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    device = next(model.parameters()).device
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+    hyps = [[blank_id] * context_size for _ in range(N)]
     decoder_input = torch.tensor(
         hyps,
         device=device,
         dtype=torch.int64,
-    )  # (batch_size, context_size)
+    )  # (N, context_size)
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
-    encoder_out = model.joiner.encoder_proj(encoder_out)
-    # decoder_out: (batch_size, 1, decoder_out_dim)
+    # decoder_out: (N, 1, decoder_out_dim)
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    offset = 0
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+        decoder_out = decoder_out[:batch_size]
         logits = model.joiner(
             current_encoder_out, decoder_out.unsqueeze(1), project_input=False
         )
@@ -390,7 +413,7 @@ def greedy_search_batch(
                 emitted = True
         if emitted:
             # update decoder output
-            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                 decoder_input,
                 device=device,
@@ -399,7 +422,12 @@ def greedy_search_batch(
             )
             decoder_out = model.decoder(decoder_input, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)
-    ans = [h[context_size:] for h in hyps]
+    sorted_ans = [h[context_size:] for h in hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
     return ans
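One more detail in this variant: the encoder frames go through model.joiner.encoder_proj exactly once, before the frame loop, and the joiner is then invoked with project_input=False so its internal projections are not re-applied on every frame. A minimal sketch of the idea (toy dimensions, a plain nn.Linear standing in for the joiner projection):

import torch

encoder_proj = torch.nn.Linear(4, 8)   # stand-in for model.joiner.encoder_proj
encoder_out = torch.randn(3, 5, 4)     # (N, T, C)

projected = encoder_proj(encoder_out)  # projected once, outside the loop
for t in range(projected.size(1)):
    frame = projected[:, t : t + 1, :]  # already projected; no per-frame Linear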
@@ -557,6 +585,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
     beam: int = 4,
 ) -> List[List[int]]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -566,6 +595,9 @@ def modified_beam_search(
       model:
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
       beam:
         Number of active paths during the beam search.
     Returns:
@@ -573,16 +605,27 @@ def modified_beam_search(
       for the i-th utterance.
     """
     assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
     device = next(model.parameters()).device
-    B = [HypothesisList() for _ in range(batch_size)]
-    for i in range(batch_size):
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
         B[i].add(
             Hypothesis(
                 ys=[blank_id] * context_size,
@@ -590,11 +633,20 @@ def modified_beam_search(
             )
         )
-    encoder_out = model.joiner.encoder_proj(encoder_out)
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
         hyps_shape = _get_hyps_shape(B).to(device)
@@ -668,8 +720,14 @@ def modified_beam_search(
                 new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                 B[i].add(new_hyp)
+    B = B + finalized_B
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
-    ans = [h.ys[context_size:] for h in best_hyps]
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
     return ans

View File

@@ -22,15 +22,15 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method greedy_search
-(2) beam search
+(2) beam search (not recommended)
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
@@ -39,7 +39,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
@@ -48,7 +48,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 1500 \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
@@ -270,6 +270,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -277,6 +278,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
@@ -356,9 +358,9 @@ def decode_dataset(
         num_batches = "?"
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):

View File

@@ -22,15 +22,15 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method greedy_search
-(2) beam search
+(2) beam search (not recommended)
 ./pruned_transducer_stateless3/decode-giga.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
@@ -39,7 +39,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
@@ -48,7 +48,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 1500 \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
@@ -224,8 +224,8 @@ def get_parser():
 def post_processing(
-    results: List[Tuple[List[List[str]], List[List[str]]]],
-) -> List[Tuple[List[List[str]], List[List[str]]]]:
+    results: List[Tuple[List[str], List[str]]],
+) -> List[Tuple[List[str], List[str]]]:
     new_results = []
     for ref, hyp in results:
         new_ref = asr_text_post_processing(" ".join(ref)).split()
@@ -415,9 +415,9 @@ def decode_dataset(
         num_batches = "?"
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):

View File

@@ -22,15 +22,15 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method greedy_search
-(2) beam search
+(2) beam search (not recommended)
 ./pruned_transducer_stateless3/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
@@ -39,7 +39,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 100 \
+    --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
@@ -48,7 +48,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless3/exp \
-    --max-duration 1500 \
+    --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
@@ -307,6 +307,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -314,6 +315,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
@@ -403,9 +405,9 @@ def decode_dataset(
         num_batches = "?"
     if params.decoding_method == "greedy_search":
-        log_interval = 100
+        log_interval = 50
     else:
-        log_interval = 2
+        log_interval = 10
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
View File
@ -22,16 +22,16 @@ Usage:
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless4/decode.py \
--epoch 30 \ --epoch 30 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search (not recommended)
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless4/decode.py \
--epoch 30 \ --epoch 30 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
@ -39,8 +39,8 @@ Usage:
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless4/decode.py \
--epoch 30 \ --epoch 30 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
@ -48,8 +48,8 @@ Usage:
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless4/decode.py \
--epoch 30 \ --epoch 30 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 1500 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
@ -70,7 +70,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -266,7 +266,7 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -284,6 +284,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -291,6 +292,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
@ -370,9 +372,9 @@ def decode_dataset(
num_batches = "?" num_batches = "?"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
log_interval = 100 log_interval = 50
else: else:
log_interval = 2 log_interval = 10
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
View File
@ -22,6 +22,235 @@ import k2
import torch import torch
from model import Transducer from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search_one_best(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
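A minimal usage sketch of `fast_beam_search_one_best` (hedged: `model`, `encoder_out`, and `encoder_out_lens` are assumed to come from a trained Transducer and its encoder; the hyper-parameter values are only illustrative):
# Sketch only: all inputs below are assumptions, not part of this diff.
decoding_graph = k2.trivial_graph(model.decoder.vocab_size - 1,
                                  device=encoder_out.device)
hyp_tokens = fast_beam_search_one_best(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,            # (N, T, C)
    encoder_out_lens=encoder_out_lens,  # (N,), valid frames before padding
    beam=4.0,
    max_states=8,
    max_contexts=4,
)
# hyp_tokens: List[List[int]] of token IDs, one entry per utterance.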
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.
This is the best result we can achieve for any nbest based rescoring
methods.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
ref_texts:
A list-of-list of integers containing the reference transcripts.
If the decoding_graph is a trivial_graph, the integer ID is the
BPE token ID.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
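A hedged sketch of how the oracle function above might be driven; `sp` (a SentencePiece processor) and `texts` (reference transcript strings) are assumed inputs that do not appear in this diff:
# Sketch only: sp, texts, model, decoding_graph, encoder_out, and
# encoder_out_lens are all assumed.
ref_texts = sp.encode(texts, out_type=int)  # List[List[int]] of BPE token IDs
oracle_tokens = fast_beam_search_nbest_oracle(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4.0,
    max_states=8,
    max_contexts=4,
    num_paths=100,          # illustrative
    ref_texts=ref_texts,
    nbest_scale=0.5,
)
# The WER of oracle_tokens lower-bounds what any n-best rescoring could reach.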
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding; it may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
lattice is actually an acceptor.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
encoder_out_len = torch.ones(1, dtype=torch.int32)
decoder_out_len = torch.ones(1, dtype=torch.int32)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len.expand(decoder_out.size(0)),
decoder_out_len.expand(decoder_out.size(0)),
) # (N, vocab_size)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
return lattice
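Note that the loop above advances every stream over all T frames, including padded ones; passing `encoder_out_lens.tolist()` to `format_output` is what appears to keep padding frames out of each utterance's lattice. A sketch of consuming the result, mirroring `fast_beam_search_one_best` above:
# Sketch only: `lattice` is the FsaVec returned by fast_beam_search().
best_path = one_best_decoding(lattice)  # shortest path per utterance
hyps = get_texts(best_path)             # List[List[int]] of token IDs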
def greedy_search( def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
@ -104,7 +333,9 @@ def greedy_search(
def greedy_search_batch( def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]: ) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args: Args:
@ -112,6 +343,9 @@ def greedy_search_batch(
The transducer model. The transducer model.
encoder_out: encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1. Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing the number of valid frames in
encoder_out before padding.
Returns: Returns:
Return a list-of-list of token IDs containing the decoded results. Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0). len(ans) equals to encoder_out.size(0).
@ -119,32 +353,54 @@ def greedy_search_batch(
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0) assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
batch_size = encoder_out.size(0) device = next(model.parameters()).device
T = encoder_out.size(1)
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
context_size = model.decoder.context_size context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)] batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor( decoder_input = torch.tensor(
hyps, hyps,
device=device, device=device,
dtype=torch.int64, dtype=torch.int64,
) # (batch_size, context_size) ) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_out: (batch_size, 1, decoder_out_dim) # decoder_out: (N, 1, decoder_out_dim)
encoder_out_len = torch.ones(batch_size, dtype=torch.int32) encoder_out_len = torch.ones(1, dtype=torch.int32)
decoder_out_len = torch.ones(batch_size, dtype=torch.int32) decoder_out_len = torch.ones(1, dtype=torch.int32)
for t in range(T): encoder_out = packed_encoder_out.data
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim) # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner( logits = model.joiner(
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len current_encoder_out,
decoder_out,
encoder_out_len.expand(batch_size),
decoder_out_len.expand(batch_size),
) # (batch_size, vocab_size) ) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape assert logits.ndim == 2, logits.shape
@ -157,7 +413,7 @@ def greedy_search_batch(
if emitted: if emitted:
# update decoder output # update decoder output
decoder_input = [h[-context_size:] for h in hyps] decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor( decoder_input = torch.tensor(
decoder_input, decoder_input,
device=device, device=device,
@ -168,7 +424,12 @@ def greedy_search_batch(
need_pad=False, need_pad=False,
) # (batch_size, 1, decoder_out_dim) ) # (batch_size, 1, decoder_out_dim)
ans = [h[context_size:] for h in hyps] sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans return ans
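To see why the packed representation lets the loop above skip padded frames, here is a small self-contained illustration with toy tensors (unrelated to any real model):
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3, 1)  # (N=2, T=3, C=1)
lens = torch.tensor([1, 3])  # utt 0 has 1 valid frame, utt 1 has 3
packed = torch.nn.utils.rnn.pack_padded_sequence(
    x, lens.cpu(), batch_first=True, enforce_sorted=False
)
print(packed.batch_sizes)       # tensor([2, 1, 1]): 2 utts alive at t=0,
                                # only 1 at t=1 and t=2, so padding never enters
print(packed.unsorted_indices)  # tensor([1, 0]): restores the original batch
                                # order, as done for sorted_ans above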
@ -415,6 +676,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
def modified_beam_search( def modified_beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4, beam: int = 4,
) -> List[List[int]]: ) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
@ -424,6 +686,9 @@ def modified_beam_search(
The transducer model. The transducer model.
encoder_out: encoder_out:
Output from the encoder. Its shape is (N, T, C). Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing the number of valid frames in
encoder_out before padding.
beam: beam:
Number of active paths during the beam search. Number of active paths during the beam search.
Returns: Returns:
@ -431,15 +696,26 @@ def modified_beam_search(
for the i-th utterance. for the i-th utterance.
""" """
assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
batch_size = encoder_out.size(0) packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
T = encoder_out.size(1) input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size): batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[blank_id] * context_size,
@ -449,9 +725,20 @@ def modified_beam_search(
encoder_out_len = torch.tensor([1]) encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1]) decoder_out_len = torch.tensor([1])
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa encoder_out = packed_encoder_out.data
offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1)
# current_encoder_out's shape is: (batch_size, 1, encoder_out_dim) # current_encoder_out's shape is: (batch_size, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = _get_hyps_shape(B).to(device) hyps_shape = _get_hyps_shape(B).to(device)
@ -524,8 +811,14 @@ def modified_beam_search(
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp) B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B] best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans return ans
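A minimal calling sketch for the updated signature, mirroring how the decode.py scripts in this commit use it (`model`, `encoder_out`, `encoder_out_lens`, and the SentencePiece processor `sp` are assumed):
hyp_tokens = modified_beam_search(
    model=model,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4,
)
for hyp in sp.decode(hyp_tokens):
    print(hyp.split())  # word list per utterance, as in decode_one_batch()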
View File
@ -19,29 +19,40 @@
Usage: Usage:
(1) greedy search (1) greedy search
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search (not recommended)
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless/exp \ --exp-dir ./transducer_stateless/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
""" """
@ -49,14 +60,16 @@ import argparse
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -115,6 +128,7 @@ def get_parser():
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search
""", """,
) )
@ -122,8 +136,35 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
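For intuition, a small worked example of this cutoff (illustrative numbers only):
# If the best state on some frame scores -3.2 and --beam is 4.0, then
# cutoff = -3.2 - 4.0 = -7.2, and states scoring below -7.2 are pruned.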
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
beam_search or modified_beam_search""", fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -149,6 +190,7 @@ def decode_one_batch(
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
batch: dict, batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -171,6 +213,9 @@ def decode_one_batch(
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`. for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -188,24 +233,44 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyp_list: List[List[int]] = []
if ( hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
): ):
hyp_list = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
for i in range(batch_size): for i in range(batch_size):
# fmt: off # fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -226,14 +291,20 @@ def decode_one_batch(
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyp_list.append(hyp) hyps.append(sp.decode(hyp).split())
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else: else:
return {f"beam_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset( def decode_dataset(
@ -241,6 +312,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -253,6 +325,9 @@ def decode_dataset(
The neural model. The neural model.
sp: sp:
The BPE model. The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -268,9 +343,9 @@ def decode_dataset(
num_batches = "?" num_batches = "?"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
log_interval = 100 log_interval = 50
else: else:
log_interval = 2 log_interval = 10
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
@ -280,6 +355,7 @@ def decode_dataset(
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph,
batch=batch, batch=batch,
) )
@ -360,13 +436,21 @@ def main():
assert params.decoding_method in ( assert params.decoding_method in (
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search",
"modified_beam_search", "modified_beam_search",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -408,6 +492,11 @@ def main():
model.eval() model.eval()
model.device = device model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -428,6 +517,7 @@ def main():
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph,
) )
save_results( save_results(
View File
@ -58,6 +58,7 @@ class Decoder(nn.Module):
padding_idx=blank_id, padding_idx=blank_id,
) )
self.blank_id = blank_id self.blank_id = blank_id
self.vocab_size = vocab_size
assert context_size >= 1, context_size assert context_size >= 1, context_size
self.context_size = context_size self.context_size = context_size
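A hedged note on why `vocab_size` is stored here: the new `fast_beam_search` in beam_search.py reads it from `model.decoder` when sizing the k2 decoding config, roughly:
# Sketch mirroring fast_beam_search() above; beam values are illustrative.
config = k2.RnntDecodingConfig(
    vocab_size=model.decoder.vocab_size,  # available thanks to this change
    decoder_history_len=model.decoder.context_size,
    beam=4.0,
    max_contexts=4,
    max_states=8,
)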
View File
@ -19,30 +19,39 @@ Usage:
(1) greedy search (1) greedy search
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \ --checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \ --method greedy_search \
--max-sym-per-frame 1 \ --max-sym-per-frame 1 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(2) beam search (2) beam search
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \ --checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(3) modified beam search (3) modified beam search
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \ --checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(4) fast beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./transducer_stateless/exp/epoch-xx.pt`. You can also use `./transducer_stateless/exp/epoch-xx.pt`.
@ -56,12 +65,14 @@ import logging
import math import math
from typing import List from typing import List
import k2
import kaldifeat import kaldifeat
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -87,9 +98,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--bpe-model", "--bpe-model",
type=str, type=str,
help="""Path to bpe.model. help="""Path to bpe.model.""",
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -100,6 +109,7 @@ def get_parser():
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search
""", """,
) )
@ -124,7 +134,33 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="Used only when --method is beam_search and modified_beam_search ", help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -241,15 +277,28 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
if params.method == "greedy_search" and params.max_sym_per_frame == 1: if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch( hyp_list = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_list = modified_beam_search( hyp_list = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
View File
@ -22,15 +22,15 @@ Usage:
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless2/exp \ --exp-dir ./transducer_stateless2/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search (not recommended)
./transducer_stateless2/decode.py \ ./transducer_stateless2/decode.py \
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless2/exp \ --exp-dir ./transducer_stateless2/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
@ -39,9 +39,20 @@ Usage:
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless2/exp \ --exp-dir ./transducer_stateless2/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search
./transducer_stateless2/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
""" """
@ -49,14 +60,16 @@ import argparse
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -115,6 +128,7 @@ def get_parser():
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search
""", """,
) )
@ -122,8 +136,35 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
beam_search or modified_beam_search""", fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -149,6 +190,7 @@ def decode_one_batch(
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
batch: dict, batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -171,6 +213,9 @@ def decode_one_batch(
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`. for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -188,24 +233,44 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyp_list: List[List[int]] = []
if ( hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
): ):
hyp_list = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
for i in range(batch_size): for i in range(batch_size):
# fmt: off # fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -226,14 +291,20 @@ def decode_one_batch(
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyp_list.append(hyp) hyps.append(sp.decode(hyp).split())
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else: else:
return {f"beam_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset( def decode_dataset(
@ -241,6 +312,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -253,6 +325,9 @@ def decode_dataset(
The neural model. The neural model.
sp: sp:
The BPE model. The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -268,9 +343,9 @@ def decode_dataset(
num_batches = "?" num_batches = "?"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
log_interval = 100 log_interval = 50
else: else:
log_interval = 2 log_interval = 10
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
@ -280,6 +355,7 @@ def decode_dataset(
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph,
batch=batch, batch=batch,
) )
@ -360,13 +436,21 @@ def main():
assert params.decoding_method in ( assert params.decoding_method in (
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search",
"modified_beam_search", "modified_beam_search",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -408,6 +492,11 @@ def main():
model.eval() model.eval()
model.device = device model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -428,6 +517,7 @@ def main():
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph,
) )
save_results( save_results(
View File
@ -19,30 +19,39 @@ Usage:
(1) greedy search (1) greedy search
./transducer_stateless2/pretrained.py \ ./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \ --method greedy_search \
--max-sym-per-frame 1 \ --max-sym-per-frame 1 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(2) beam search (2) beam search
./transducer_stateless2/pretrained.py \ ./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(3) modified beam search (3) modified beam search
./transducer_stateless2/pretrained.py \ ./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav \ /path/to/bar.wav
(4) fast beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./transducer_stateless2/exp/epoch-xx.pt`. You can also use `./transducer_stateless2/exp/epoch-xx.pt`.
@ -56,12 +65,14 @@ import logging
import math import math
from typing import List from typing import List
import k2
import kaldifeat import kaldifeat
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -87,9 +98,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--bpe-model", "--bpe-model",
type=str, type=str,
help="""Path to bpe.model. help="""Path to bpe.model.""",
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -100,6 +109,7 @@ def get_parser():
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search
""", """,
) )
@ -124,7 +134,33 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="Used only when --method is beam_search and modified_beam_search ", help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -241,15 +277,28 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
if params.method == "greedy_search" and params.max_sym_per_frame == 1: if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch( hyp_list = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_list = modified_beam_search( hyp_list = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
else: else:
View File
@ -22,17 +22,37 @@ Usage:
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \ --exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search (not recommended)
./transducer_stateless_multi_datasets/decode.py \ ./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \ --epoch 14 \
--avg 7 \ --avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \ --exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 100 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search
./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless_multi_datasets/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
""" """
@ -40,14 +60,16 @@ import argparse
import logging import logging
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -107,6 +129,7 @@ def get_parser():
- greedy_search - greedy_search
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search
""", """,
) )
@ -114,8 +137,35 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
beam_search or modified_beam_search""", fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
) )
parser.add_argument( parser.add_argument(
@ -141,6 +191,7 @@ def decode_one_batch(
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
batch: dict, batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -163,6 +214,9 @@ def decode_one_batch(
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`. for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -180,24 +234,44 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyp_list = []
batch_size = encoder_out.size(0)
if ( hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
): ):
hyp_list = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else: else:
batch_size = encoder_out.size(0)
for i in range(batch_size): for i in range(batch_size):
# fmt: off # fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -218,14 +292,20 @@ def decode_one_batch(
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyp_list.append(sp.decode(hyp).split()) hyps.append(sp.decode(hyp).split())
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else: else:
return {f"beam_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset( def decode_dataset(
@ -233,6 +313,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -245,6 +326,9 @@ def decode_dataset(
The neural model. The neural model.
sp: sp:
The BPE model. The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -260,9 +344,9 @@ def decode_dataset(
num_batches = "?" num_batches = "?"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
log_interval = 100 log_interval = 50
else: else:
log_interval = 2 log_interval = 10
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
@ -272,6 +356,7 @@ def decode_dataset(
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph,
batch=batch, batch=batch,
) )
@@ -352,13 +437,21 @@ def main():
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
+        "fast_beam_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -402,6 +495,11 @@ def main():
     model.eval()
     model.device = device

+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -423,6 +521,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
         )

         save_results(

View File

@@ -44,6 +44,15 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav

+(4) fast beam search
+./transducer_stateless_multi_datasets/pretrained.py \
+    --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
 You can also use `./transducer_stateless_multi_datasets/exp/epoch-xx.pt`.

 Note: ./transducer_stateless_multi_datasets/exp/pretrained.pt is generated by
@@ -56,12 +65,14 @@ import logging
 import math
 from typing import List

+import k2
 import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (
     beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -87,9 +98,7 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        help="""Path to bpe.model.
-        Used only when method is ctc-decoding.
-        """,
+        help="""Path to bpe.model.""",
     )

     parser.add_argument(
@@ -100,6 +109,7 @@ def get_parser():
         - greedy_search
         - beam_search
         - modified_beam_search
+        - fast_beam_search
         """,
     )
@@ -124,7 +134,33 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --method is beam_search and modified_beam_search ",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
     )

     parser.add_argument(
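The `--beam` help text above defines the cutoff as `max-score - beam`. A toy sketch of that pruning rule (pure illustration with hypothetical names; this is not the fast_beam_search implementation):

# Toy illustration of the beam cutoff described in the --beam help text:
# only hypotheses whose score is within `beam` of the best one survive.
def prune_by_beam(scores: list, beam: float) -> list:
    cutoff = max(scores) - beam
    return [s for s in scores if s >= cutoff]

# With beam=4.0, scores more than 4 below the maximum are dropped.
assert prune_by_beam([0.0, -3.5, -10.0], beam=4.0) == [0.0, -3.5]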
@@ -241,18 +277,30 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

-    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_list = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_list = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
     elif params.method == "modified_beam_search":
         hyp_list = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
     else:
         for i in range(num_waves):
             # fmt: off
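The `encoder_out_lens` arguments threaded into greedy_search_batch and modified_beam_search above are the point of this commit: encoder outputs are padded to the longest utterance in the batch, and decoding must stop at each utterance's true frame count. A rough sketch of the idea (not the icefall implementation):

# Rough sketch: why batched decoding needs encoder_out_lens.  Without it,
# padding frames past each utterance's true length could emit spurious tokens.
import torch

def valid_frames(encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor):
    # encoder_out: (N, T, C) padded batch; encoder_out_lens: (N,) true lengths.
    for i in range(encoder_out.size(0)):
        yield encoder_out[i, : encoder_out_lens[i]]  # padding frames dropped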

View File

@@ -69,7 +69,7 @@ import torch.nn as nn
 from asr_datamodule import TedLiumAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -237,7 +237,7 @@ def decode_one_batch(
     hyps = []

     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -255,6 +255,7 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -262,6 +263,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
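The rename from fast_beam_search to fast_beam_search_one_best signals that the helper returns the single best path from the search lattice. A guess at the underlying operation, using k2's shortest-path (an assumption about the internals, not code from this commit):

# Assumed sketch: extracting the one-best path from an FSA-based decoding
# lattice with k2.  fast_beam_search_one_best presumably does something similar.
import k2

def one_best(lattice: k2.Fsa) -> k2.Fsa:
    return k2.shortest_path(lattice, use_double_scores=True)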

View File

@@ -72,23 +72,16 @@ import k2
 import kaldifeat
 import sentencepiece as spm
 import torch
-import torch.nn as nn
 import torchaudio
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
 )
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model

-from icefall.env import get_env_info
 from icefall.utils import AttributeDict


 def get_parser():
@@ -185,76 +178,16 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     return parser
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "sample_rate": 16000,
-            # parameters for conformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.vocab_size,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
-        blank_id=params.blank_id,
-        unk_id=params.unk_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
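With the local builders deleted above, model construction now comes from the `from train import get_params, get_transducer_model` line added earlier in this file. A minimal sketch of the resulting setup; the comment about vocab_size/blank_id is an assumption based on how other icefall recipes fill in params:

# Sketch: pretrained.py now reuses the training script's builders instead of
# duplicating Conformer/Decoder/Joiner construction.
from train import get_params, get_transducer_model

params = get_params()
# In the real script, fields such as vocab_size and blank_id are derived
# from the BPE model before the model is built (an assumption, not shown here).
model = get_transducer_model(params)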
@@ -354,7 +287,7 @@ def main():
     logging.info(msg)

     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -372,6 +305,7 @@ def main():
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -379,6 +313,7 @@ def main():
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):