mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Ignore padding frames during RNN-T decoding. (#358)
* Ignore padding frames during RNN-T decoding.
* Fix outdated decoding code.
* Minor fixes.
This commit is contained in:
parent
bc284e88e6
commit
aeb8986e35
@ -33,7 +33,7 @@ for sym in 1 2 3; do
|
|||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search; do
|
for method in fast_beam_search modified_beam_search beam_search; do
|
||||||
log "$method"
|
log "$method"
|
||||||
|
|
||||||
./pruned_transducer_stateless/pretrained.py \
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
@ -47,7 +47,8 @@ for method in modified_beam_search beam_search; do
|
|||||||
done
|
done
|
||||||
|
|
||||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
mkdir -p pruned_transducer_stateless/exp
|
mkdir -p pruned_transducer_stateless/exp
|
||||||
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt
|
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt
|
||||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
@ -58,9 +59,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
|||||||
log "Decoding test-clean and test-other"
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
# use a small value for decoding with CPU
|
# use a small value for decoding with CPU
|
||||||
max_duration=50
|
max_duration=100
|
||||||
|
|
||||||
for method in greedy_search fast_beam_search; do
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
log "Decoding with $method"
|
log "Decoding with $method"
|
||||||
|
|
||||||
./pruned_transducer_stateless/decode.py \
|
./pruned_transducer_stateless/decode.py \
|
||||||
|
@ -51,7 +51,8 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
done
|
done
|
||||||
|
|
||||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
mkdir -p pruned_transducer_stateless2/exp
|
mkdir -p pruned_transducer_stateless2/exp
|
||||||
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt
|
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt
|
||||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
@ -62,9 +63,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
|||||||
log "Decoding test-clean and test-other"
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
# use a small value for decoding with CPU
|
# use a small value for decoding with CPU
|
||||||
max_duration=50
|
max_duration=100
|
||||||
|
|
||||||
for method in greedy_search fast_beam_search; do
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
log "Decoding with $method"
|
log "Decoding with $method"
|
||||||
|
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
@ -51,7 +51,8 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
done
|
done
|
||||||
|
|
||||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
mkdir -p pruned_transducer_stateless3/exp
|
mkdir -p pruned_transducer_stateless3/exp
|
||||||
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
|
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
|
||||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
@ -62,9 +63,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
|||||||
log "Decoding test-clean and test-other"
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
# use a small value for decoding with CPU
|
# use a small value for decoding with CPU
|
||||||
max_duration=50
|
max_duration=100
|
||||||
|
|
||||||
for method in greedy_search fast_beam_search; do
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
log "Decoding with $method"
|
log "Decoding with $method"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
@ -33,7 +33,7 @@ for sym in 1 2 3; do
|
|||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search; do
|
for method in fast_beam_search modified_beam_search beam_search; do
|
||||||
log "$method"
|
log "$method"
|
||||||
|
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
@ -47,7 +47,8 @@ for method in modified_beam_search beam_search; do
|
|||||||
done
|
done
|
||||||
|
|
||||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
mkdir -p transducer_stateless2/exp
|
mkdir -p transducer_stateless2/exp
|
||||||
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt
|
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt
|
||||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
@ -58,9 +59,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
|
|||||||
log "Decoding test-clean and test-other"
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
# use a small value for decoding with CPU
|
# use a small value for decoding with CPU
|
||||||
max_duration=50
|
max_duration=100
|
||||||
|
|
||||||
for method in greedy_search modified_beam_search; do
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
log "Decoding with $method"
|
log "Decoding with $method"
|
||||||
|
|
||||||
./transducer_stateless2/decode.py \
|
./transducer_stateless2/decode.py \
|
||||||
|
@ -33,7 +33,7 @@ for sym in 1 2 3; do
|
|||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search; do
|
for method in modified_beam_search beam_search fast_beam_search; do
|
||||||
log "$method"
|
log "$method"
|
||||||
|
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
|
|||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
|
mkdir -p transducer_stateless_multi_datasets/exp
|
||||||
|
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
|
||||||
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|
||||||
|
ls -lh data
|
||||||
|
ls -lh transducer_stateless_multi_datasets/exp
|
||||||
|
|
||||||
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
|
# use a small value for decoding with CPU
|
||||||
|
max_duration=100
|
||||||
|
|
||||||
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
log "Decoding with $method"
|
||||||
|
|
||||||
|
./transducer_stateless_multi_datasets/decode.py \
|
||||||
|
--decoding-method $method \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration $max_duration \
|
||||||
|
--exp-dir transducer_stateless_multi_datasets/exp
|
||||||
|
done
|
||||||
|
|
||||||
|
rm transducer_stateless_multi_datasets/exp/*.pt
|
||||||
|
fi
|
||||||
|
@ -33,7 +33,7 @@ for sym in 1 2 3; do
|
|||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search; do
|
for method in modified_beam_search beam_search fast_beam_search; do
|
||||||
log "$method"
|
log "$method"
|
||||||
|
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
@ -45,3 +45,32 @@ for method in modified_beam_search beam_search; do
|
|||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
|
mkdir -p transducer_stateless_multi_datasets/exp
|
||||||
|
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
|
||||||
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|
||||||
|
ls -lh data
|
||||||
|
ls -lh transducer_stateless_multi_datasets/exp
|
||||||
|
|
||||||
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
|
# use a small value for decoding with CPU
|
||||||
|
max_duration=100
|
||||||
|
|
||||||
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
log "Decoding with $method"
|
||||||
|
|
||||||
|
./transducer_stateless_multi_datasets/decode.py \
|
||||||
|
--decoding-method $method \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration $max_duration \
|
||||||
|
--exp-dir transducer_stateless_multi_datasets/exp
|
||||||
|
done
|
||||||
|
|
||||||
|
rm transducer_stateless_multi_datasets/exp/*.pt
|
||||||
|
fi
|
||||||
|
@ -33,7 +33,7 @@ for sym in 1 2 3; do
|
|||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search; do
|
for method in fast_beam_search modified_beam_search beam_search; do
|
||||||
log "$method"
|
log "$method"
|
||||||
|
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
@ -46,15 +46,31 @@ for method in modified_beam_search beam_search; do
|
|||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search; do
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
log "$method"
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
|
mkdir -p transducer_stateless/exp
|
||||||
|
ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless/exp/epoch-999.pt
|
||||||
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
ls -lh data
|
||||||
--method $method \
|
ls -lh transducer_stateless/exp
|
||||||
--beam-size 4 \
|
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
log "Decoding test-clean and test-other"
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
# use a small value for decoding with CPU
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
max_duration=100
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
|
||||||
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
log "Decoding with $method"
|
||||||
|
|
||||||
|
./transducer_stateless/decode.py \
|
||||||
|
--decoding-method $method \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration $max_duration \
|
||||||
|
--exp-dir transducer_stateless/exp
|
||||||
done
|
done
|
||||||
|
|
||||||
|
rm transducer_stateless/exp/*.pt
|
||||||
|
fi
|
||||||
|
14
.github/workflows/run-librispeech-2022-03-12.yml
vendored
14
.github/workflows/run-librispeech-2022-03-12.yml
vendored
@ -35,7 +35,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_2022_03_12:
|
run_librispeech_2022_03_12:
|
||||||
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -107,11 +107,11 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||||
|
|
||||||
|
|
||||||
- name: Inference with pre-trained model
|
- name: Inference with pre-trained model
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
run: |
|
run: |
|
||||||
mkdir -p egs/librispeech/ASR/data
|
mkdir -p egs/librispeech/ASR/data
|
||||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||||
@ -124,8 +124,8 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
|
.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
|
||||||
|
|
||||||
- name: Display decoding results
|
- name: Display decoding results for pruned_transducer_stateless
|
||||||
if: github.event_name == 'schedule'
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
cd egs/librispeech/ASR/
|
cd egs/librispeech/ASR/
|
||||||
@ -141,9 +141,13 @@ jobs:
|
|||||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===modified beam search==="
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
- name: Upload decoding results for pruned_transducer_stateless
|
- name: Upload decoding results for pruned_transducer_stateless
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v2
|
||||||
if: github.event_name == 'schedule'
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
with:
|
with:
|
||||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
|
||||||
path: egs/librispeech/ASR/pruned_transducer_stateless/exp/
|
path: egs/librispeech/ASR/pruned_transducer_stateless/exp/
|
||||||
|
47
.github/workflows/run-librispeech-2022-04-29.yml
vendored
47
.github/workflows/run-librispeech-2022-04-29.yml
vendored
@ -35,7 +35,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_2022_04_29:
|
run_librispeech_2022_04_29:
|
||||||
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -111,6 +111,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
run: |
|
run: |
|
||||||
mkdir -p egs/librispeech/ASR/data
|
mkdir -p egs/librispeech/ASR/data
|
||||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||||
@ -125,44 +126,54 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
|
.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
|
||||||
|
|
||||||
- name: Display decoding results
|
- name: Display decoding results for pruned_transducer_stateless2
|
||||||
if: github.event_name == 'schedule'
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
tree pruned_transducer_stateless2/exp
|
tree pruned_transducer_stateless2/exp
|
||||||
cd pruned_transducer_stateless2
|
cd pruned_transducer_stateless2/exp
|
||||||
echo "results for pruned_transducer_stateless2"
|
|
||||||
echo "===greedy search==="
|
echo "===greedy search==="
|
||||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
echo "===fast_beam_search==="
|
echo "===fast_beam_search==="
|
||||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
cd ../
|
echo "===modified beam search==="
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
- name: Display decoding results for pruned_transducer_stateless3
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd egs/librispeech/ASR
|
||||||
tree pruned_transducer_stateless3/exp
|
tree pruned_transducer_stateless3/exp
|
||||||
cd pruned_transducer_stateless3
|
cd pruned_transducer_stateless3/exp
|
||||||
echo "results for pruned_transducer_stateless3"
|
|
||||||
echo "===greedy search==="
|
echo "===greedy search==="
|
||||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
echo "===fast_beam_search==="
|
echo "===fast_beam_search==="
|
||||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===modified beam search==="
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
- name: Upload decoding results for pruned_transducer_stateless2
|
- name: Upload decoding results for pruned_transducer_stateless2
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v2
|
||||||
if: github.event_name == 'schedule'
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
with:
|
with:
|
||||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
|
||||||
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
|
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
|
||||||
|
|
||||||
- name: Upload decoding results for pruned_transducer_stateless3
|
- name: Upload decoding results for pruned_transducer_stateless3
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v2
|
||||||
if: github.event_name == 'schedule'
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
with:
|
with:
|
||||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
|
||||||
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
|
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
|
||||||
|
@ -35,7 +35,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_2022_04_19:
|
run_librispeech_2022_04_19:
|
||||||
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -111,6 +111,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
run: |
|
run: |
|
||||||
mkdir -p egs/librispeech/ASR/data
|
mkdir -p egs/librispeech/ASR/data
|
||||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||||
@ -124,7 +125,7 @@ jobs:
|
|||||||
.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
|
.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
|
||||||
|
|
||||||
- name: Display decoding results
|
- name: Display decoding results
|
||||||
if: github.event_name == 'schedule'
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
cd egs/librispeech/ASR/
|
cd egs/librispeech/ASR/
|
||||||
@ -136,13 +137,17 @@ jobs:
|
|||||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===fast_beam_search==="
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
echo "===modified_beam_search==="
|
echo "===modified_beam_search==="
|
||||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
- name: Upload decoding results for transducer_stateless2
|
- name: Upload decoding results for transducer_stateless2
|
||||||
uses: actions/upload-artifact@v2
|
uses: actions/upload-artifact@v2
|
||||||
if: github.event_name == 'schedule'
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
with:
|
with:
|
||||||
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
|
||||||
path: egs/librispeech/ASR/transducer_stateless2/exp/
|
path: egs/librispeech/ASR/transducer_stateless2/exp/
|
||||||
|
@ -23,9 +23,18 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
types: [labeled]
|
types: [labeled]
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
# minute (0-59)
|
||||||
|
# hour (0-23)
|
||||||
|
# day of the month (1-31)
|
||||||
|
# month (1-12)
|
||||||
|
# day of the week (0-6)
|
||||||
|
# nightly build at 15:50 UTC time every day
|
||||||
|
- cron: "50 15 * * *"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
|
run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
|
||||||
if: github.event.label.name == 'ready' || github.event_name == 'push'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -64,11 +73,80 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
.github/scripts/install-kaldifeat.sh
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
- name: Inference with pre-trained model
|
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||||
|
id: libri-test-clean-and-test-other-data
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/download
|
||||||
|
key: cache-libri-test-clean-and-test-other
|
||||||
|
|
||||||
|
- name: Download LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||||
|
|
||||||
|
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||||
|
|
||||||
|
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||||
|
id: libri-test-clean-and-test-other-fbank
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/fbank-libri
|
||||||
|
key: cache-libri-fbank-test-clean-and-test-other
|
||||||
|
|
||||||
|
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
run: |
|
||||||
|
mkdir -p egs/librispeech/ASR/data
|
||||||
|
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||||
|
ls -lh egs/librispeech/ASR/data/*
|
||||||
|
|
||||||
sudo apt-get -qq install git-lfs tree sox
|
sudo apt-get -qq install git-lfs tree sox
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
|
.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
|
||||||
|
|
||||||
|
- name: Display decoding results for transducer_stateless_multi_datasets
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd egs/librispeech/ASR/
|
||||||
|
tree ./transducer_stateless_multi_datasets/exp
|
||||||
|
|
||||||
|
cd transducer_stateless_multi_datasets
|
||||||
|
echo "results for transducer_stateless_multi_datasets"
|
||||||
|
echo "===greedy search==="
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===fast_beam_search==="
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===modified beam search==="
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
- name: Upload decoding results for transducer_stateless_multi_datasets
|
||||||
|
uses: actions/upload-artifact@v2
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
with:
|
||||||
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
|
||||||
|
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
|
||||||
|
@ -23,9 +23,18 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
types: [labeled]
|
types: [labeled]
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
# minute (0-59)
|
||||||
|
# hour (0-23)
|
||||||
|
# day of the month (1-31)
|
||||||
|
# month (1-12)
|
||||||
|
# day of the week (0-6)
|
||||||
|
# nightly build at 15:50 UTC time every day
|
||||||
|
- cron: "50 15 * * *"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
|
run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
|
||||||
if: github.event.label.name == 'ready' || github.event_name == 'push'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -64,11 +73,80 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
.github/scripts/install-kaldifeat.sh
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
- name: Inference with pre-trained model
|
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||||
|
id: libri-test-clean-and-test-other-data
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/download
|
||||||
|
key: cache-libri-test-clean-and-test-other
|
||||||
|
|
||||||
|
- name: Download LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||||
|
|
||||||
|
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||||
|
|
||||||
|
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||||
|
id: libri-test-clean-and-test-other-fbank
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/fbank-libri
|
||||||
|
key: cache-libri-fbank-test-clean-and-test-other
|
||||||
|
|
||||||
|
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
run: |
|
||||||
|
mkdir -p egs/librispeech/ASR/data
|
||||||
|
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||||
|
ls -lh egs/librispeech/ASR/data/*
|
||||||
|
|
||||||
sudo apt-get -qq install git-lfs tree sox
|
sudo apt-get -qq install git-lfs tree sox
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
|
.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
|
||||||
|
|
||||||
|
- name: Display decoding results for transducer_stateless_multi_datasets
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd egs/librispeech/ASR/
|
||||||
|
tree ./transducer_stateless_multi_datasets/exp
|
||||||
|
|
||||||
|
cd transducer_stateless_multi_datasets
|
||||||
|
echo "results for transducer_stateless_multi_datasets"
|
||||||
|
echo "===greedy search==="
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===fast_beam_search==="
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===modified beam search==="
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
- name: Upload decoding results for transducer_stateless_multi_datasets
|
||||||
|
uses: actions/upload-artifact@v2
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
with:
|
||||||
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
|
||||||
|
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
name: run-pre-trained-trandsucer-stateless
|
name: run-pre-trained-transducer-stateless
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -23,9 +23,18 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
types: [labeled]
|
types: [labeled]
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
# minute (0-59)
|
||||||
|
# hour (0-23)
|
||||||
|
# day of the month (1-31)
|
||||||
|
# month (1-12)
|
||||||
|
# day of the week (0-6)
|
||||||
|
# nightly build at 15:50 UTC time every day
|
||||||
|
- cron: "50 15 * * *"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_pre_trained_transducer_stateless:
|
run_pre_trained_transducer_stateless:
|
||||||
if: github.event.label.name == 'ready' || github.event_name == 'push'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -64,11 +73,80 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
.github/scripts/install-kaldifeat.sh
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
- name: Inference with pre-trained model
|
- name: Cache LibriSpeech test-clean and test-other datasets
|
||||||
|
id: libri-test-clean-and-test-other-data
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/download
|
||||||
|
key: cache-libri-test-clean-and-test-other
|
||||||
|
|
||||||
|
- name: Download LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
|
||||||
|
|
||||||
|
- name: Prepare manifests for LibriSpeech test-clean and test-other
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
|
||||||
|
|
||||||
|
- name: Cache LibriSpeech test-clean and test-other fbank features
|
||||||
|
id: libri-test-clean-and-test-other-fbank
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/fbank-libri
|
||||||
|
key: cache-libri-fbank-test-clean-and-test-other
|
||||||
|
|
||||||
|
- name: Compute fbank for LibriSpeech test-clean and test-other
|
||||||
|
if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
run: |
|
||||||
|
mkdir -p egs/librispeech/ASR/data
|
||||||
|
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
||||||
|
ls -lh egs/librispeech/ASR/data/*
|
||||||
|
|
||||||
sudo apt-get -qq install git-lfs tree sox
|
sudo apt-get -qq install git-lfs tree sox
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
.github/scripts/run-pre-trained-transducer-stateless.sh
|
.github/scripts/run-pre-trained-transducer-stateless.sh
|
||||||
|
|
||||||
|
- name: Display decoding results for transducer_stateless
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd egs/librispeech/ASR/
|
||||||
|
tree ./transducer_stateless/exp
|
||||||
|
|
||||||
|
cd transducer_stateless
|
||||||
|
echo "results for transducer_stateless"
|
||||||
|
echo "===greedy search==="
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===fast_beam_search==="
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===modified beam search==="
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
- name: Upload decoding results for transducer_stateless
|
||||||
|
uses: actions/upload-artifact@v2
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
with:
|
||||||
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
|
||||||
|
path: egs/librispeech/ASR/transducer_stateless/exp/
|
||||||
|
@ -110,6 +110,8 @@ class Conformer(Transformer):
|
|||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
# Caution: We assume the subsampling factor is 4!
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
mask = make_pad_mask(lengths)
|
||||||
|
@ -25,8 +25,8 @@ Usage:
|
|||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./transducer_stateless_modified/decode.py \
|
./transducer_stateless_modified-2/decode.py \
|
||||||
--epoch 89 \
|
--epoch 89 \
|
||||||
--avg 38 \
|
--avg 38 \
|
||||||
--exp-dir ./transducer_stateless_modified-2/exp \
|
--exp-dir ./transducer_stateless_modified-2/exp \
|
||||||
@ -38,30 +38,43 @@ Usage:
|
|||||||
./transducer_stateless_modified-2/decode.py \
|
./transducer_stateless_modified-2/decode.py \
|
||||||
--epoch 89 \
|
--epoch 89 \
|
||||||
--avg 38 \
|
--avg 38 \
|
||||||
--exp-dir ./transducer_stateless_modified/exp \
|
--exp-dir ./transducer_stateless_modified-2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless_modified-2/decode.py \
|
||||||
|
--epoch 89 \
|
||||||
|
--avg 38 \
|
||||||
|
--exp-dir ./transducer_stateless_modified-2/exp \
|
||||||
|
--max-duration 100 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from aishell import AIShell
|
from aishell import AIShell
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
from beam_search import (
|
||||||
from conformer import Conformer
|
beam_search,
|
||||||
from decoder import Decoder
|
fast_beam_search_one_best,
|
||||||
from joiner import Joiner
|
greedy_search,
|
||||||
from model import Transducer
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -114,6 +127,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,8 +135,35 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --decoding-method is beam_search "
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
"and modified_beam_search",
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -132,84 +173,24 @@ def get_parser():
|
|||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=1,
|
||||||
help="Maximum number of symbols per frame",
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding-method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"encoder_out_dim": 512,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.encoder_out_dim,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.encoder_out_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.encoder_out_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
token_table: k2.SymbolTable,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -230,8 +211,8 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
lexicon:
|
token_table:
|
||||||
It contains the token symbol table and the word symbol table.
|
It maps token ID to a string.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -249,9 +230,36 @@ def decode_one_batch(
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
hyps = []
|
|
||||||
batch_size = encoder_out.size(0)
|
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
params.decoding_method == "greedy_search"
|
||||||
|
and params.max_sym_per_frame == 1
|
||||||
|
):
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hyp_tokens = []
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
@ -264,29 +272,38 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model,
|
||||||
)
|
encoder_out=encoder_out_i,
|
||||||
elif params.decoding_method == "modified_beam_search":
|
beam=params.beam_size,
|
||||||
hyp = modified_beam_search(
|
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyps.append([lexicon.token_table[i] for i in hyp])
|
hyp_tokens.append(hyp)
|
||||||
|
|
||||||
|
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
token_table: k2.SymbolTable,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -297,6 +314,11 @@ def decode_dataset(
|
|||||||
It is returned by :func:`get_params`.
|
It is returned by :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The neural model.
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
It maps a token ID to a string.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||||
|
only when --decoding-method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -312,9 +334,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -323,7 +345,8 @@ def decode_dataset(
|
|||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
token_table=token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -358,6 +381,7 @@ def save_results(
|
|||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
@ -408,13 +432,21 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if "beam_search" in params.decoding_method:
|
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -456,6 +488,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -472,7 +509,8 @@ def main():
|
|||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
token_table=lexicon.token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
@ -484,8 +522,5 @@ def main():
|
|||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
# greedy search
|
(1) greedy search
|
||||||
./transducer_stateless_modified-2/pretrained.py \
|
./transducer_stateless_modified-2/pretrained.py \
|
||||||
--checkpoint /path/to/pretrained.pt \
|
--checkpoint /path/to/pretrained.pt \
|
||||||
--lang-dir /path/to/lang_char \
|
--lang-dir /path/to/lang_char \
|
||||||
@ -27,7 +27,7 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
# beam search
|
(2) beam search
|
||||||
./transducer_stateless_modified-2/pretrained.py \
|
./transducer_stateless_modified-2/pretrained.py \
|
||||||
--checkpoint /path/to/pretrained.pt \
|
--checkpoint /path/to/pretrained.pt \
|
||||||
--lang-dir /path/to/lang_char \
|
--lang-dir /path/to/lang_char \
|
||||||
@ -36,7 +36,7 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
# modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless_modified-2/pretrained.py \
|
./transducer_stateless_modified-2/pretrained.py \
|
||||||
--checkpoint /path/to/pretrained.pt \
|
--checkpoint /path/to/pretrained.pt \
|
||||||
--lang-dir /path/to/lang_char \
|
--lang-dir /path/to/lang_char \
|
||||||
@ -45,6 +45,14 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless_modified-2/pretrained.py \
|
||||||
|
--checkpoint /path/to/pretrained.pt \
|
||||||
|
--lang-dir /path/to/lang_char \
|
||||||
|
--method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -53,11 +61,13 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -97,6 +107,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,7 +132,33 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search and modified_beam_search",
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -134,11 +171,10 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=1,
|
||||||
help="Maximum number of symbols per frame. "
|
help="Maximum number of symbols per frame. "
|
||||||
"Use only when --method is greedy_search",
|
"Use only when --method is greedy_search",
|
||||||
)
|
)
|
||||||
return parser
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -225,20 +261,37 @@ def main():
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=features, x_lens=feature_lens
|
x=features, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_waves = encoder_out.size(0)
|
||||||
hyp_list = []
|
hyp_list = []
|
||||||
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
logging.info(f"Using {params.method}")
|
||||||
|
|
||||||
|
if params.method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
hyp_list = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_list = greedy_search_batch(
|
hyp_list = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_list = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
for i in range(encoder_out.size(0)):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -19,18 +19,18 @@
|
|||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless_modified/decode.py \
|
./transducer_stateless_modified/decode.py \
|
||||||
--epoch 64 \
|
--epoch 14 \
|
||||||
--avg 33 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless_modified/exp \
|
--exp-dir ./transducer_stateless_modified/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./transducer_stateless_modified/decode.py \
|
./transducer_stateless_modified/decode.py \
|
||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless_modified/exp \
|
--exp-dir ./transducer_stateless_modified/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,28 +39,43 @@ Usage:
|
|||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless_modified/exp \
|
--exp-dir ./transducer_stateless_modified/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless_modified/decode.py \
|
||||||
|
--epoch 14 \
|
||||||
|
--avg 7 \
|
||||||
|
--exp-dir ./transducer_stateless_modified/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from asr_datamodule import AishellAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
from beam_search import (
|
||||||
from conformer import Conformer
|
beam_search,
|
||||||
from decoder import Decoder
|
fast_beam_search_one_best,
|
||||||
from joiner import Joiner
|
greedy_search,
|
||||||
from model import Transducer
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -113,6 +128,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -120,7 +136,35 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --decoding-method is beam_search",
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -130,84 +174,24 @@ def get_parser():
|
|||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=1,
|
||||||
help="Maximum number of symbols per frame",
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"encoder_out_dim": 512,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.encoder_out_dim,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.encoder_out_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.encoder_out_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
token_table: k2.SymbolTable,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -228,8 +212,11 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
lexicon:
|
token_table:
|
||||||
It contains the token symbol table and the word symbol table.
|
It maps token ID to a string.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||||
|
only when --decoding-method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -247,9 +234,36 @@ def decode_one_batch(
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
hyps = []
|
|
||||||
batch_size = encoder_out.size(0)
|
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
params.decoding_method == "greedy_search"
|
||||||
|
and params.max_sym_per_frame == 1
|
||||||
|
):
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hyp_tokens = []
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
@ -262,29 +276,38 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model,
|
||||||
)
|
encoder_out=encoder_out_i,
|
||||||
elif params.decoding_method == "modified_beam_search":
|
beam=params.beam_size,
|
||||||
hyp = modified_beam_search(
|
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyps.append([lexicon.token_table[i] for i in hyp])
|
hyp_tokens.append(hyp)
|
||||||
|
|
||||||
|
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
token_table: k2.SymbolTable,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -295,6 +318,11 @@ def decode_dataset(
|
|||||||
It is returned by :func:`get_params`.
|
It is returned by :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The neural model.
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
It maps a token ID to a string.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||||
|
only when --decoding-method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -310,9 +338,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -321,7 +349,8 @@ def decode_dataset(
|
|||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
token_table=token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -356,6 +385,7 @@ def save_results(
|
|||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
@ -406,13 +436,21 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if "beam_search" in params.decoding_method:
|
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -452,6 +490,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -467,7 +510,8 @@ def main():
|
|||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
token_table=lexicon.token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
@ -479,8 +523,5 @@ def main():
|
|||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
# greedy search
|
(1) greedy search
|
||||||
./transducer_stateless_modified/pretrained.py \
|
./transducer_stateless_modified/pretrained.py \
|
||||||
--checkpoint /path/to/pretrained.pt \
|
--checkpoint /path/to/pretrained.pt \
|
||||||
--lang-dir /path/to/lang_char \
|
--lang-dir /path/to/lang_char \
|
||||||
@ -27,7 +27,7 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
# beam search
|
(2) beam search
|
||||||
./transducer_stateless_modified/pretrained.py \
|
./transducer_stateless_modified/pretrained.py \
|
||||||
--checkpoint /path/to/pretrained.pt \
|
--checkpoint /path/to/pretrained.pt \
|
||||||
--lang-dir /path/to/lang_char \
|
--lang-dir /path/to/lang_char \
|
||||||
@ -36,7 +36,7 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
# modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless_modified/pretrained.py \
|
./transducer_stateless_modified/pretrained.py \
|
||||||
--checkpoint /path/to/pretrained.pt \
|
--checkpoint /path/to/pretrained.pt \
|
||||||
--lang-dir /path/to/lang_char \
|
--lang-dir /path/to/lang_char \
|
||||||
@ -45,6 +45,14 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless_modified/pretrained.py \
|
||||||
|
--checkpoint /path/to/pretrained.pt \
|
||||||
|
--lang-dir /path/to/lang_char \
|
||||||
|
--method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -53,11 +61,13 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -97,6 +107,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,7 +132,33 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search and modified_beam_search",
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -134,11 +171,10 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=1,
|
||||||
help="Maximum number of symbols per frame. "
|
help="Maximum number of symbols per frame. "
|
||||||
"Use only when --method is greedy_search",
|
"Use only when --method is greedy_search",
|
||||||
)
|
)
|
||||||
return parser
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -225,20 +261,37 @@ def main():
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=features, x_lens=feature_lens
|
x=features, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_waves = encoder_out.size(0)
|
||||||
hyp_list = []
|
hyp_list = []
|
||||||
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
logging.info(f"Using {params.method}")
|
||||||
|
|
||||||
|
if params.method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
hyp_list = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_list = greedy_search_batch(
|
hyp_list = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_list = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
for i in range(encoder_out.size(0)):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -27,6 +27,149 @@ from icefall.decode import Nbest, one_best_decoding
|
|||||||
from icefall.utils import get_texts
|
from icefall.utils import get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_one_best(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
A lattice is first obtained using fast beam search, and then
|
||||||
|
the shortest path within the lattice is used as the final output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi.
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts per stream per frame.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_path = one_best_decoding(lattice)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_nbest_oracle(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
num_paths: int,
|
||||||
|
ref_texts: List[List[int]],
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
nbest_scale: float = 0.5,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
A lattice is first obtained using fast beam search, and then
|
||||||
|
we select `num_paths` linear paths from the lattice. The path
|
||||||
|
that has the minimum edit distance with the given reference transcript
|
||||||
|
is used as the output.
|
||||||
|
|
||||||
|
This is the best result we can achieve for any nbest based rescoring
|
||||||
|
methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi.
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts per stream per frame.
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the decoded lattice.
|
||||||
|
ref_texts:
|
||||||
|
A list-of-list of integers containing the reference transcripts.
|
||||||
|
If the decoding_graph is a trivial_graph, the integer ID is the
|
||||||
|
BPE token ID.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
hyps = nbest.build_levenshtein_graphs()
|
||||||
|
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
|
||||||
|
|
||||||
|
levenshtein_alignment = k2.levenshtein_alignment(
|
||||||
|
refs=refs,
|
||||||
|
hyps=hyps,
|
||||||
|
hyp_to_ref_map=nbest.shape.row_ids(1),
|
||||||
|
sorted_match_ref=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tot_scores = levenshtein_alignment.get_tot_scores(
|
||||||
|
use_double_scores=False, log_semiring=False
|
||||||
|
)
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
|
|
||||||
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
|
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search(
|
def fast_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
decoding_graph: k2.Fsa,
|
decoding_graph: k2.Fsa,
|
||||||
@ -35,8 +178,7 @@ def fast_beam_search(
|
|||||||
beam: float,
|
beam: float,
|
||||||
max_states: int,
|
max_states: int,
|
||||||
max_contexts: int,
|
max_contexts: int,
|
||||||
use_max: bool = False,
|
) -> k2.Fsa:
|
||||||
) -> List[List[int]]:
|
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -55,11 +197,10 @@ def fast_beam_search(
|
|||||||
Max states per stream per frame.
|
Max states per stream per frame.
|
||||||
max_contexts:
|
max_contexts:
|
||||||
Max contexts per stream per frame.
|
Max contexts per stream per frame.
|
||||||
use_max:
|
|
||||||
True to use max operation to select the hypothesis with the largest
|
|
||||||
log_prob when there are duplicate hypotheses; False to use log-add.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return an FsaVec with axes [utt][state][arc] containing the decoded
|
||||||
|
lattice. Note: When the input graph is a TrivialGraph, the returned
|
||||||
|
lattice is actually an acceptor.
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
@ -92,7 +233,7 @@ def fast_beam_search(
|
|||||||
# (shape.NumElements(), 1, encoder_out_dim)
|
# (shape.NumElements(), 1, encoder_out_dim)
|
||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = torch.index_select(
|
current_encoder_out = torch.index_select(
|
||||||
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).long()
|
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
|
||||||
# in some old versions of pytorch, the type of index requires
|
# in some old versions of pytorch, the type of index requires
|
||||||
# to be LongTensor. In the newest version of pytorch, the type
|
# to be LongTensor. In the newest version of pytorch, the type
|
||||||
# of index can be IntTensor or LongTensor. For supporting the
|
# of index can be IntTensor or LongTensor. For supporting the
|
||||||
@ -109,67 +250,7 @@ def fast_beam_search(
|
|||||||
decoding_streams.terminate_and_flush_to_streams()
|
decoding_streams.terminate_and_flush_to_streams()
|
||||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||||
|
|
||||||
if use_max:
|
return lattice
|
||||||
best_path = one_best_decoding(lattice)
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
return hyps
|
|
||||||
else:
|
|
||||||
num_paths = 200
|
|
||||||
use_double_scores = True
|
|
||||||
nbest_scale = 0.8
|
|
||||||
|
|
||||||
nbest = Nbest.from_lattice(
|
|
||||||
lattice=lattice,
|
|
||||||
num_paths=num_paths,
|
|
||||||
use_double_scores=use_double_scores,
|
|
||||||
nbest_scale=nbest_scale,
|
|
||||||
)
|
|
||||||
# The following code is modified from nbest.intersect()
|
|
||||||
word_fsa = k2.invert(nbest.fsa)
|
|
||||||
if hasattr(lattice, "aux_labels"):
|
|
||||||
# delete token IDs as it is not needed
|
|
||||||
del word_fsa.aux_labels
|
|
||||||
word_fsa.scores.zero_()
|
|
||||||
|
|
||||||
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
|
|
||||||
path_to_utt_map = nbest.shape.row_ids(1)
|
|
||||||
|
|
||||||
if hasattr(lattice, "aux_labels"):
|
|
||||||
# lattice has token IDs as labels and word IDs as aux_labels.
|
|
||||||
# inv_lattice has word IDs as labels and token IDs as aux_labels
|
|
||||||
inv_lattice = k2.invert(lattice)
|
|
||||||
inv_lattice = k2.arc_sort(inv_lattice)
|
|
||||||
else:
|
|
||||||
inv_lattice = k2.arc_sort(lattice)
|
|
||||||
|
|
||||||
if inv_lattice.shape[0] == 1:
|
|
||||||
path_lattice = k2.intersect_device(
|
|
||||||
inv_lattice,
|
|
||||||
word_fsa_with_epsilon_loops,
|
|
||||||
b_to_a_map=torch.zeros_like(path_to_utt_map),
|
|
||||||
sorted_match_a=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
path_lattice = k2.intersect_device(
|
|
||||||
inv_lattice,
|
|
||||||
word_fsa_with_epsilon_loops,
|
|
||||||
b_to_a_map=path_to_utt_map,
|
|
||||||
sorted_match_a=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# path_lattice has word IDs as labels and token IDs as aux_labels
|
|
||||||
path_lattice = k2.top_sort(k2.connect(path_lattice))
|
|
||||||
|
|
||||||
tot_scores = path_lattice.get_tot_scores(
|
|
||||||
use_double_scores=use_double_scores, log_semiring=True
|
|
||||||
)
|
|
||||||
|
|
||||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
|
||||||
best_hyp_indexes = ragged_tot_scores.argmax()
|
|
||||||
|
|
||||||
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
return hyps
|
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
@ -193,10 +274,10 @@ def greedy_search(
|
|||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = model.decoder.unk_id
|
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
|
||||||
device = model.device
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||||
@ -230,7 +311,7 @@ def greedy_search(
|
|||||||
# logits is (1, 1, 1, vocab_size)
|
# logits is (1, 1, 1, vocab_size)
|
||||||
|
|
||||||
y = logits.argmax().item()
|
y = logits.argmax().item()
|
||||||
if y != blank_id and y != unk_id:
|
if y not in (blank_id, unk_id):
|
||||||
hyp.append(y)
|
hyp.append(y)
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[hyp[-context_size:]], device=device
|
[hyp[-context_size:]], device=device
|
||||||
@ -249,7 +330,9 @@ def greedy_search(
|
|||||||
|
|
||||||
|
|
||||||
def greedy_search_batch(
|
def greedy_search_batch(
|
||||||
model: Transducer, encoder_out: torch.Tensor
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
Args:
|
Args:
|
||||||
@ -257,6 +340,9 @@ def greedy_search_batch(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list of token IDs containing the decoded results.
|
Return a list-of-list of token IDs containing the decoded results.
|
||||||
len(ans) equals encoder_out.size(0).
|
len(ans) equals encoder_out.size(0).
|
||||||
@ -264,28 +350,48 @@ def greedy_search_batch(
|
|||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
device = model.device
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
batch_size = encoder_out.size(0)
|
device = next(model.parameters()).device
|
||||||
T = encoder_out.size(1)
|
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = model.decoder.unk_id
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
hyps,
|
hyps,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
) # (batch_size, context_size)
|
) # (N, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
# decoder_out: (N, 1, decoder_out_dim)
|
||||||
for t in range(T):
|
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
encoder_out = packed_encoder_out.data
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
|
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
|
||||||
# logits' shape: (batch_size, 1, 1, vocab_size)
|
# logits' shape: (batch_size, 1, 1, vocab_size)
|
||||||
|
|
||||||
@ -294,12 +400,12 @@ def greedy_search_batch(
|
|||||||
y = logits.argmax(dim=1).tolist()
|
y = logits.argmax(dim=1).tolist()
|
||||||
emitted = False
|
emitted = False
|
||||||
for i, v in enumerate(y):
|
for i, v in enumerate(y):
|
||||||
if v != blank_id and v != unk_id:
|
if v not in (blank_id, unk_id):
|
||||||
hyps[i].append(v)
|
hyps[i].append(v)
|
||||||
emitted = True
|
emitted = True
|
||||||
if emitted:
|
if emitted:
|
||||||
# update decoder output
|
# update decoder output
|
||||||
decoder_input = [h[-context_size:] for h in hyps]
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
device=device,
|
device=device,
|
||||||
@ -307,7 +413,12 @@ def greedy_search_batch(
|
|||||||
)
|
)
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
|
||||||
ans = [h[context_size:] for h in hyps]
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@ -472,6 +583,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
|||||||
def modified_beam_search(
|
def modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
use_max: bool = False,
|
use_max: bool = False,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
@ -482,6 +594,9 @@ def modified_beam_search(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
beam:
|
beam:
|
||||||
Number of active paths during the beam search.
|
Number of active paths during the beam search.
|
||||||
use_max:
|
use_max:
|
||||||
@ -492,16 +607,27 @@ def modified_beam_search(
|
|||||||
for the i-th utterance.
|
for the i-th utterance.
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3, encoder_out.shape
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
batch_size = encoder_out.size(0)
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
T = encoder_out.size(1)
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = model.decoder.unk_id
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = model.device
|
device = next(model.parameters()).device
|
||||||
B = [HypothesisList() for _ in range(batch_size)]
|
|
||||||
for i in range(batch_size):
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
B[i].add(
|
B[i].add(
|
||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
@ -510,9 +636,20 @@ def modified_beam_search(
|
|||||||
use_max=use_max,
|
use_max=use_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
for t in range(T):
|
encoder_out = packed_encoder_out.data
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
hyps_shape = _get_hyps_shape(B).to(device)
|
hyps_shape = _get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
@ -577,15 +714,21 @@ def modified_beam_search(
|
|||||||
|
|
||||||
new_ys = hyp.ys[:]
|
new_ys = hyp.ys[:]
|
||||||
new_token = topk_token_indexes[k]
|
new_token = topk_token_indexes[k]
|
||||||
if new_token != blank_id and new_token != unk_id:
|
if new_token not in (blank_id, unk_id):
|
||||||
new_ys.append(new_token)
|
new_ys.append(new_token)
|
||||||
|
|
||||||
new_log_prob = topk_log_probs[k]
|
new_log_prob = topk_log_probs[k]
|
||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
B[i].add(new_hyp)
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
ans = [h.ys[context_size:] for h in best_hyps]
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
@ -622,10 +765,10 @@ def _deprecated_modified_beam_search(
|
|||||||
# support only batch_size == 1 for now
|
# support only batch_size == 1 for now
|
||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = model.decoder.unk_id
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
device = model.device
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
T = encoder_out.size(1)
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
@ -691,7 +834,7 @@ def _deprecated_modified_beam_search(
|
|||||||
hyp = A[topk_hyp_indexes[i]]
|
hyp = A[topk_hyp_indexes[i]]
|
||||||
new_ys = hyp.ys[:]
|
new_ys = hyp.ys[:]
|
||||||
new_token = topk_token_indexes[i]
|
new_token = topk_token_indexes[i]
|
||||||
if new_token != blank_id and new_token != unk_id:
|
if new_token not in (blank_id, unk_id):
|
||||||
new_ys.append(new_token)
|
new_ys.append(new_token)
|
||||||
new_log_prob = topk_log_probs[i]
|
new_log_prob = topk_log_probs[i]
|
||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
@ -732,10 +875,10 @@ def beam_search(
|
|||||||
# support only batch_size == 1 for now
|
# support only batch_size == 1 for now
|
||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = model.decoder.unk_id
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
device = model.device
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[blank_id] * context_size,
|
[blank_id] * context_size,
|
||||||
@ -818,7 +961,7 @@ def beam_search(
|
|||||||
# Second, process other non-blank labels
|
# Second, process other non-blank labels
|
||||||
values, indices = log_prob.topk(beam + 1)
|
values, indices = log_prob.topk(beam + 1)
|
||||||
for i, v in zip(indices.tolist(), values.tolist()):
|
for i, v in zip(indices.tolist(), values.tolist()):
|
||||||
if i == blank_id or i == unk_id:
|
if i in (blank_id, unk_id):
|
||||||
continue
|
continue
|
||||||
new_ys = y_star.ys + [i]
|
new_ys = y_star.ys + [i]
|
||||||
new_log_prob = y_star.log_prob + v
|
new_log_prob = y_star.log_prob + v
|
||||||
|
@ -22,15 +22,15 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless/decode.py \
|
./pruned_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--max-duration 1500 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
@ -61,7 +61,7 @@ Usage:
|
|||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--use-LG True \
|
--use-LG True \
|
||||||
--use-max False \
|
--use-max False \
|
||||||
--max-duration 1500 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 8 \
|
--beam 8 \
|
||||||
--max-contexts 8 \
|
--max-contexts 8 \
|
||||||
@ -82,7 +82,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -307,7 +307,7 @@ def decode_one_batch(
|
|||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -315,7 +315,6 @@ def decode_one_batch(
|
|||||||
beam=params.beam,
|
beam=params.beam,
|
||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
use_max=params.use_max,
|
|
||||||
)
|
)
|
||||||
if params.use_LG:
|
if params.use_LG:
|
||||||
for hyp in hyp_tokens:
|
for hyp in hyp_tokens:
|
||||||
@ -330,6 +329,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -337,6 +337,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
use_max=params.use_max,
|
use_max=params.use_max,
|
||||||
)
|
)
|
||||||
@ -421,9 +422,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
@ -25,7 +25,7 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav \
|
||||||
|
|
||||||
(1) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless/pretrained.py \
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
@ -34,6 +34,24 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav \
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--method modified_beam_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav \
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav \
|
||||||
|
|
||||||
You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
|
You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
|
||||||
|
|
||||||
Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by
|
Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by
|
||||||
@ -46,12 +64,14 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -77,9 +97,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="""Path to bpe.model.""",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -90,6 +108,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -114,7 +133,33 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search and modified_beam_search",
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -230,10 +275,25 @@ def main():
|
|||||||
if params.method == "beam_search":
|
if params.method == "beam_search":
|
||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
if params.method == "modified_beam_search":
|
|
||||||
|
if params.method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -243,6 +303,7 @@ def main():
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
@ -335,7 +335,9 @@ def greedy_search(
|
|||||||
|
|
||||||
|
|
||||||
def greedy_search_batch(
|
def greedy_search_batch(
|
||||||
model: Transducer, encoder_out: torch.Tensor
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
Args:
|
Args:
|
||||||
@ -343,6 +345,9 @@ def greedy_search_batch(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list of token IDs containing the decoded results.
|
Return a list-of-list of token IDs containing the decoded results.
|
||||||
len(ans) equals encoder_out.size(0).
|
len(ans) equals encoder_out.size(0).
|
||||||
@ -350,31 +355,49 @@ def greedy_search_batch(
|
|||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
device = next(model.parameters()).device
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
batch_size = encoder_out.size(0)
|
device = next(model.parameters()).device
|
||||||
T = encoder_out.size(1)
|
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
hyps,
|
hyps,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
) # (batch_size, context_size)
|
) # (N, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
# decoder_out: (N, 1, decoder_out_dim)
|
||||||
|
|
||||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
for t in range(T):
|
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
offset = 0
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||||
)
|
)
|
||||||
@ -390,7 +413,7 @@ def greedy_search_batch(
|
|||||||
emitted = True
|
emitted = True
|
||||||
if emitted:
|
if emitted:
|
||||||
# update decoder output
|
# update decoder output
|
||||||
decoder_input = [h[-context_size:] for h in hyps]
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
device=device,
|
device=device,
|
||||||
@ -399,7 +422,12 @@ def greedy_search_batch(
|
|||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
ans = [h[context_size:] for h in hyps]
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@ -557,6 +585,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
|||||||
def modified_beam_search(
|
def modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
@ -566,6 +595,9 @@ def modified_beam_search(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
beam:
|
beam:
|
||||||
Number of active paths during the beam search.
|
Number of active paths during the beam search.
|
||||||
Returns:
|
Returns:
|
||||||
@ -573,16 +605,27 @@ def modified_beam_search(
|
|||||||
for the i-th utterance.
|
for the i-th utterance.
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3, encoder_out.shape
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
batch_size = encoder_out.size(0)
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
T = encoder_out.size(1)
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
B = [HypothesisList() for _ in range(batch_size)]
|
|
||||||
for i in range(batch_size):
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
B[i].add(
|
B[i].add(
|
||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
@ -590,11 +633,20 @@ def modified_beam_search(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
for t in range(T):
|
offset = 0
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
finalized_B = []
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
hyps_shape = _get_hyps_shape(B).to(device)
|
hyps_shape = _get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
@ -668,8 +720,14 @@ def modified_beam_search(
|
|||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
B[i].add(new_hyp)
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
ans = [h.ys[context_size:] for h in best_hyps]
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
@ -22,15 +22,15 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 1500 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
@ -270,6 +270,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -277,6 +278,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -356,9 +358,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
@ -22,15 +22,15 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless3/decode-giga.py \
|
./pruned_transducer_stateless3/decode-giga.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 1500 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
@ -224,8 +224,8 @@ def get_parser():
|
|||||||
|
|
||||||
|
|
||||||
def post_processing(
|
def post_processing(
|
||||||
results: List[Tuple[List[List[str]], List[List[str]]]],
|
results: List[Tuple[List[str], List[str]]],
|
||||||
) -> List[Tuple[List[List[str]], List[List[str]]]]:
|
) -> List[Tuple[List[str], List[str]]]:
|
||||||
new_results = []
|
new_results = []
|
||||||
for ref, hyp in results:
|
for ref, hyp in results:
|
||||||
new_ref = asr_text_post_processing(" ".join(ref)).split()
|
new_ref = asr_text_post_processing(" ".join(ref)).split()
|
||||||
@ -415,9 +415,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
@ -22,15 +22,15 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 1500 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
@ -307,6 +307,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -314,6 +315,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -403,9 +405,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
@ -22,16 +22,16 @@ Usage:
|
|||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,8 +39,8 @@ Usage:
|
|||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -48,8 +48,8 @@ Usage:
|
|||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 1500 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
@ -70,7 +70,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -266,7 +266,7 @@ def decode_one_batch(
|
|||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -284,6 +284,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -291,6 +292,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
@ -370,9 +372,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
@ -22,6 +22,235 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
|
from icefall.decode import Nbest, one_best_decoding
|
||||||
|
from icefall.utils import get_texts
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_one_best(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
A lattice is first obtained using fast beam search, and then
|
||||||
|
the shortest path within the lattice is used as the final output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi.
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts per stream per frame.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_path = one_best_decoding(lattice)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_nbest_oracle(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
num_paths: int,
|
||||||
|
ref_texts: List[List[int]],
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
nbest_scale: float = 0.5,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
A lattice is first obtained using fast beam search, and then
|
||||||
|
we select `num_paths` linear paths from the lattice. The path
|
||||||
|
that has the minimum edit distance with the given reference transcript
|
||||||
|
is used as the output.
|
||||||
|
|
||||||
|
This is the best result we can achieve for any nbest based rescoring
|
||||||
|
methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi.
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts per stream per frame.
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the decoded lattice.
|
||||||
|
ref_texts:
|
||||||
|
A list-of-list of integers containing the reference transcripts.
|
||||||
|
If the decoding_graph is a trivial_graph, the integer ID is the
|
||||||
|
BPE token ID.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
hyps = nbest.build_levenshtein_graphs()
|
||||||
|
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
|
||||||
|
|
||||||
|
levenshtein_alignment = k2.levenshtein_alignment(
|
||||||
|
refs=refs,
|
||||||
|
hyps=hyps,
|
||||||
|
hyp_to_ref_map=nbest.shape.row_ids(1),
|
||||||
|
sorted_match_ref=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tot_scores = levenshtein_alignment.get_tot_scores(
|
||||||
|
use_double_scores=False, log_semiring=False
|
||||||
|
)
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
|
|
||||||
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
|
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
) -> k2.Fsa:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi.
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts per stream per frame.
|
||||||
|
Returns:
|
||||||
|
Return an FsaVec with axes [utt][state][arc] containing the decoded
|
||||||
|
lattice. Note: When the input graph is a TrivialGraph, the returned
|
||||||
|
lattice is actually an acceptor.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
vocab_size = model.decoder.vocab_size
|
||||||
|
|
||||||
|
B, T, C = encoder_out.shape
|
||||||
|
|
||||||
|
config = k2.RnntDecodingConfig(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
decoder_history_len=context_size,
|
||||||
|
beam=beam,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
max_states=max_states,
|
||||||
|
)
|
||||||
|
individual_streams = []
|
||||||
|
for i in range(B):
|
||||||
|
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
|
||||||
|
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
|
||||||
|
|
||||||
|
encoder_out_len = torch.ones(1, dtype=torch.int32)
|
||||||
|
decoder_out_len = torch.ones(1, dtype=torch.int32)
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
# shape is a RaggedShape of shape (B, context)
|
||||||
|
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
||||||
|
shape, contexts = decoding_streams.get_contexts()
|
||||||
|
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
|
||||||
|
contexts = contexts.to(torch.int64)
|
||||||
|
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
||||||
|
decoder_out = model.decoder(contexts, need_pad=False)
|
||||||
|
# current_encoder_out is of shape
|
||||||
|
# (shape.NumElements(), 1, joiner_dim)
|
||||||
|
# fmt: off
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
encoder_out_len.expand(decoder_out.size(0)),
|
||||||
|
decoder_out_len.expand(decoder_out.size(0)),
|
||||||
|
) # (N, vocab_size)
|
||||||
|
log_probs = logits.log_softmax(dim=-1)
|
||||||
|
decoding_streams.advance(log_probs)
|
||||||
|
decoding_streams.terminate_and_flush_to_streams()
|
||||||
|
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||||
|
|
||||||
|
return lattice
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||||
@ -104,7 +333,9 @@ def greedy_search(
|
|||||||
|
|
||||||
|
|
||||||
def greedy_search_batch(
|
def greedy_search_batch(
|
||||||
model: Transducer, encoder_out: torch.Tensor
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
Args:
|
Args:
|
||||||
@ -112,6 +343,9 @@ def greedy_search_batch(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list of token IDs containing the decoded results.
|
Return a list-of-list of token IDs containing the decoded results.
|
||||||
len(ans) equals to encoder_out.size(0).
|
len(ans) equals to encoder_out.size(0).
|
||||||
@ -119,32 +353,54 @@ def greedy_search_batch(
|
|||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
device = model.device
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
batch_size = encoder_out.size(0)
|
device = next(model.parameters()).device
|
||||||
T = encoder_out.size(1)
|
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
hyps,
|
hyps,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
) # (batch_size, context_size)
|
) # (N, context_size)
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
# decoder_out: (N, 1, decoder_out_dim)
|
||||||
|
|
||||||
encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
|
encoder_out_len = torch.ones(1, dtype=torch.int32)
|
||||||
decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
|
decoder_out_len = torch.ones(1, dtype=torch.int32)
|
||||||
|
|
||||||
for t in range(T):
|
encoder_out = packed_encoder_out.data
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
|
||||||
|
offset = 0
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1)
|
||||||
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
decoder_out = decoder_out[:batch_size]
|
||||||
|
|
||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
encoder_out_len.expand(batch_size),
|
||||||
|
decoder_out_len.expand(batch_size),
|
||||||
) # (batch_size, vocab_size)
|
) # (batch_size, vocab_size)
|
||||||
|
|
||||||
assert logits.ndim == 2, logits.shape
|
assert logits.ndim == 2, logits.shape
|
||||||
@ -157,7 +413,7 @@ def greedy_search_batch(
|
|||||||
|
|
||||||
if emitted:
|
if emitted:
|
||||||
# update decoder output
|
# update decoder output
|
||||||
decoder_input = [h[-context_size:] for h in hyps]
|
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
device=device,
|
device=device,
|
||||||
@ -168,7 +424,12 @@ def greedy_search_batch(
|
|||||||
need_pad=False,
|
need_pad=False,
|
||||||
) # (batch_size, 1, decoder_out_dim)
|
) # (batch_size, 1, decoder_out_dim)
|
||||||
|
|
||||||
ans = [h[context_size:] for h in hyps]
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@ -415,6 +676,7 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
|||||||
def modified_beam_search(
|
def modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
|
||||||
@ -424,6 +686,9 @@ def modified_beam_search(
|
|||||||
The transducer model.
|
The transducer model.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
beam:
|
beam:
|
||||||
Number of active paths during the beam search.
|
Number of active paths during the beam search.
|
||||||
Returns:
|
Returns:
|
||||||
@ -431,15 +696,26 @@ def modified_beam_search(
|
|||||||
for the i-th utterance.
|
for the i-th utterance.
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == 3, encoder_out.shape
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
batch_size = encoder_out.size(0)
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
T = encoder_out.size(1)
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
device = model.device
|
device = next(model.parameters()).device
|
||||||
B = [HypothesisList() for _ in range(batch_size)]
|
|
||||||
for i in range(batch_size):
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
B[i].add(
|
B[i].add(
|
||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
@ -449,9 +725,20 @@ def modified_beam_search(
|
|||||||
|
|
||||||
encoder_out_len = torch.tensor([1])
|
encoder_out_len = torch.tensor([1])
|
||||||
decoder_out_len = torch.tensor([1])
|
decoder_out_len = torch.tensor([1])
|
||||||
for t in range(T):
|
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
encoder_out = packed_encoder_out.data
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for batch_size in batch_size_list:
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1)
|
||||||
# current_encoder_out's shape is: (batch_size, 1, encoder_out_dim)
|
# current_encoder_out's shape is: (batch_size, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
hyps_shape = _get_hyps_shape(B).to(device)
|
hyps_shape = _get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
@ -524,8 +811,14 @@ def modified_beam_search(
|
|||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
B[i].add(new_hyp)
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
ans = [h.ys[context_size:] for h in best_hyps]
|
|
||||||
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
ans = []
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
for i in range(N):
|
||||||
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
@ -22,15 +22,15 @@ Usage:
|
|||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./transducer_stateless/decode.py \
|
./transducer_stateless/decode.py \
|
||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,9 +39,20 @@ Usage:
|
|||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless/decode.py \
|
||||||
|
--epoch 14 \
|
||||||
|
--avg 7 \
|
||||||
|
--exp-dir ./transducer_stateless/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -49,14 +60,16 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -115,6 +128,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -122,8 +136,35 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
beam_search or modified_beam_search""",
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -149,6 +190,7 @@ def decode_one_batch(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -171,6 +213,9 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -188,24 +233,44 @@ def decode_one_batch(
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
hyp_list: List[List[int]] = []
|
|
||||||
|
|
||||||
if (
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
):
|
):
|
||||||
hyp_list = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
@ -226,14 +291,20 @@ def decode_one_batch(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyp_list.append(hyp)
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -241,6 +312,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -253,6 +325,9 @@ def decode_dataset(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -268,9 +343,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -280,6 +355,7 @@ def decode_dataset(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -360,13 +436,21 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if "beam_search" in params.decoding_method:
|
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -408,6 +492,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -428,6 +517,7 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -58,6 +58,7 @@ class Decoder(nn.Module):
|
|||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
|
@ -24,7 +24,7 @@ Usage:
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
@ -33,7 +33,7 @@ Usage:
|
|||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
@ -42,7 +42,16 @@ Usage:
|
|||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless/pretrained.py \
|
||||||
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--method fast_beam_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
|
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
|
||||||
|
|
||||||
@ -56,12 +65,14 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -87,9 +98,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="""Path to bpe.model.""",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -100,6 +109,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -124,7 +134,33 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search and modified_beam_search ",
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -241,15 +277,28 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
if params.method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
hyp_list = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_list = greedy_search_batch(
|
hyp_list = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_list = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -22,15 +22,15 @@ Usage:
|
|||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless2/exp \
|
--exp-dir ./transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./transducer_stateless2/decode.py \
|
./transducer_stateless2/decode.py \
|
||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless2/exp \
|
--exp-dir ./transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
@ -39,9 +39,20 @@ Usage:
|
|||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless2/exp \
|
--exp-dir ./transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless2/decode.py \
|
||||||
|
--epoch 14 \
|
||||||
|
--avg 7 \
|
||||||
|
--exp-dir ./transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -49,14 +60,16 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -115,6 +128,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -122,8 +136,35 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
beam_search or modified_beam_search""",
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -149,6 +190,7 @@ def decode_one_batch(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -171,6 +213,9 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -188,24 +233,44 @@ def decode_one_batch(
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
hyp_list: List[List[int]] = []
|
|
||||||
|
|
||||||
if (
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
):
|
):
|
||||||
hyp_list = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
@ -226,14 +291,20 @@ def decode_one_batch(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyp_list.append(hyp)
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -241,6 +312,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -253,6 +325,9 @@ def decode_dataset(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||||
|
only when --decoding-method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_size_7" if beam size of 7 is used.
|
||||||
@ -268,9 +343,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -280,6 +355,7 @@ def decode_dataset(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -360,13 +436,21 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if "beam_search" in params.decoding_method:
|
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -408,6 +492,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -428,6 +517,7 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -24,7 +24,7 @@ Usage:
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
@ -33,7 +33,7 @@ Usage:
|
|||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
@ -42,7 +42,16 @@ Usage:
|
|||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless2/pretrained.py \
|
||||||
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
You can also use `./transducer_stateless2/exp/epoch-xx.pt`.
|
You can also use `./transducer_stateless2/exp/epoch-xx.pt`.
|
||||||
|
|
||||||
@ -56,12 +65,14 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -87,9 +98,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="""Path to bpe.model.""",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -100,6 +109,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -124,7 +134,33 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search and modified_beam_search ",
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -241,15 +277,28 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
if params.method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
hyp_list = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_list = greedy_search_batch(
|
hyp_list = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_list = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -22,17 +22,37 @@ Usage:
|
|||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search (not recommended)
|
||||||
./transducer_stateless_multi_datasets/decode.py \
|
./transducer_stateless_multi_datasets/decode.py \
|
||||||
--epoch 14 \
|
--epoch 14 \
|
||||||
--avg 7 \
|
--avg 7 \
|
||||||
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./transducer_stateless_multi_datasets/decode.py \
|
||||||
|
--epoch 14 \
|
||||||
|
--avg 7 \
|
||||||
|
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless_multi_datasets/decode.py \
|
||||||
|
--epoch 14 \
|
||||||
|
--avg 7 \
|
||||||
|
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -40,14 +60,16 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -107,6 +129,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -114,8 +137,35 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
beam_search or modified_beam_search""",
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -141,6 +191,7 @@ def decode_one_batch(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -163,6 +214,9 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||||
|
only when --decoding-method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -180,24 +234,44 @@ def decode_one_batch(
|
|||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
)
|
)
|
||||||
hyp_list = []
|
|
||||||
batch_size = encoder_out.size(0)
|
|
||||||
|
|
||||||
if (
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
):
|
):
|
||||||
hyp_list = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
else:
|
else:
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
@ -218,14 +292,20 @@ def decode_one_batch(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyp_list.append(sp.decode(hyp).split())
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {f"beam_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -233,6 +313,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -245,6 +326,9 @@ def decode_dataset(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||||
|
only when --decoding-method is fast_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_size_7" if beam size of 7 is used.
|
||||||
@ -260,9 +344,9 @@ def decode_dataset(
|
|||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 100
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 2
|
log_interval = 10
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -272,6 +356,7 @@ def decode_dataset(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -352,13 +437,21 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if "beam_search" in params.decoding_method:
|
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -402,6 +495,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -423,6 +521,7 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -44,6 +44,15 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--method fast_beam_search \
|
||||||
|
--beam 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
You can also use `./transducer_stateless_multi_datasets/exp/epoch-xx.pt`.
|
You can also use `./transducer_stateless_multi_datasets/exp/epoch-xx.pt`.
|
||||||
|
|
||||||
Note: ./transducer_stateless_multi_datasets/exp/pretrained.pt is generated by
|
Note: ./transducer_stateless_multi_datasets/exp/pretrained.pt is generated by
|
||||||
@ -56,12 +65,14 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -87,9 +98,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="""Path to bpe.model.""",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -100,6 +109,7 @@ def get_parser():
|
|||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -124,7 +134,33 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search and modified_beam_search ",
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -241,18 +277,30 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
if params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
if params.method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
hyp_list = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_list = greedy_search_batch(
|
hyp_list = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_list = modified_beam_search(
|
hyp_list = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
@ -69,7 +69,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import TedLiumAsrDataModule
|
from asr_datamodule import TedLiumAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -237,7 +237,7 @@ def decode_one_batch(
|
|||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -255,6 +255,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -262,6 +263,7 @@ def decode_one_batch(
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
@ -72,23 +72,16 @@ import k2
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from conformer import Conformer
|
|
||||||
from decoder import Decoder
|
|
||||||
from joiner import Joiner
|
|
||||||
from model import Transducer
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.utils import AttributeDict
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -185,76 +178,16 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
"sample_rate": 16000,
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
# parameters for decoder
|
|
||||||
"embedding_dim": 512,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.embedding_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
unk_id=params.unk_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.vocab_size,
|
|
||||||
inner_dim=params.embedding_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def read_sound_files(
|
def read_sound_files(
|
||||||
filenames: List[str], expected_sample_rate: float
|
filenames: List[str], expected_sample_rate: float
|
||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
@ -354,7 +287,7 @@ def main():
|
|||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
hyp_tokens = fast_beam_search(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
model=model,
|
model=model,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -372,6 +305,7 @@ def main():
|
|||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -379,6 +313,7 @@ def main():
|
|||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user