From f23dd437196178e11de9f1b283053a3a9565f63d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 14 May 2022 21:45:39 +0800 Subject: [PATCH] Update results for libri+giga multi dataset setup. (#363) * Update results for libri+giga multi dataset setup. --- ...pruned-transducer-stateless3-2022-05-13.sh | 80 ++++++++++ .../workflows/run-librispeech-2022-04-29.yml | 8 +- ...runed-transducer-stateless3-2022-05-13.yml | 151 ++++++++++++++++++ README.md | 2 +- egs/librispeech/ASR/RESULTS.md | 63 +++++++- .../pruned_transducer_stateless/pretrained.py | 49 +++--- .../pretrained.py | 82 +++++++--- .../pretrained.py | 82 +++++++--- .../ASR/pruned_transducer_stateless3/train.py | 1 + 9 files changed, 451 insertions(+), 67 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh create mode 100644 .github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh new file mode 100755 index 000000000..3617bc369 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt +popd + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless3/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless3/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless3/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless3/exp + done + + rm pruned_transducer_stateless3/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml index e3fe3b904..6c8188b48 100644 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ b/.github/workflows/run-librispeech-2022-04-29.yml @@ -142,8 +142,8 @@ jobs: find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Display decoding results for pruned_transducer_stateless3 if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' @@ -161,8 +161,8 @@ jobs: find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for pruned_transducer_stateless2 uses: actions/upload-artifact@v2 diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml new file mode 100644 index 000000000..512f1b334 --- /dev/null +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -0,0 +1,151 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-pruned-transducer-stateless3-2022-05-13 +# stateless pruned transducer (reworked model) + giga speech + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_pruned_transducer_stateless3_2022_05_13: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.7, 3.8, 3.9] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }} + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh + + - name: Display decoding results for pruned_transducer_stateless3 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR + tree pruned_transducer_stateless3/exp + cd pruned_transducer_stateless3/exp + echo "===greedy search===" + find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for pruned_transducer_stateless3 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29 + path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/ diff --git a/README.md b/README.md index c4dad6aaf..47bf0e212 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles | | test-clean | test-other | |-----|------------|------------| -| WER | 2.19 | 4.97 | +| WER | 2.00 | 4.63 | ### Aishell diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 3143fa077..874700f11 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,6 +1,6 @@ ## Results -### LibriSpeech BPE training results (Pruned Transducer 3) +### LibriSpeech BPE training results (Pruned Transducer 3, 2022-04-29) [pruned_transducer_stateless3](./pruned_transducer_stateless3) Same as `Pruned Transducer 2` but using the XL subset from @@ -152,6 +152,67 @@ for epoch in 27; do done ``` +### LibriSpeech BPE training results (Pruned Transducer 3, 2022-05-13) + +Same setup as [pruned_transducer_stateless3](./pruned_transducer_stateless3) (2022-04-29) +but change `--giga-prob` from 0.8 to 0.9. Also use `repeat` on gigaspeech XL +subset so that the gigaspeech dataloader never exhausts. + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|---------------------------------------------| +| greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 | +| fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 | + +The training commands are: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./prepare.sh +./prepare_giga_speech.sh + +./pruned_transducer_stateless3/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 0 \ + --full-libri 1 \ + --exp-dir pruned_transducer_stateless3/exp-0.9 \ + --max-duration 300 \ + --use-fp16 1 \ + --lr-epochs 4 \ + --num-workers 2 \ + --giga-prob 0.9 +``` + +The tensorboard log is available at + + +Decoding commands: + +```bash +for iter in 1224000; do + for avg in 14; do + for method in greedy_search modified_beam_search fast_beam_search ; do + ./pruned_transducer_stateless3/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \ + --max-duration 600 \ + --decoding-method $method \ + --max-sym-per-frame 1 \ + --beam 4 \ + --max-contexts 32 + done + done +done +``` + +The pretrained models, training logs, decoding logs, and decoding results +can be found at + + + ### LibriSpeech BPE training results (Pruned Transducer 2) [pruned_transducer_stateless2](./pruned_transducer_stateless2) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 148bf7b02..d21a737b8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -19,38 +19,38 @@ Usage: (1) greedy search ./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav (2) beam search ./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav (4) fast beam search ./pruned_transducer_stateless/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`. @@ -233,6 +233,9 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + checkpoint = torch.load(args.checkpoint, map_location="cpu") model.load_state_dict(checkpoint["model"], strict=False) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index bcafe68d6..21bcf7cfd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -19,20 +19,38 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav -(1) beam search +(2) beam search ./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless2/pretrained.py \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless2/pretrained.py \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. @@ -79,9 +97,7 @@ def get_parser(): parser.add_argument( "--bpe-model", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="""Path to bpe.model.""", ) parser.add_argument( @@ -117,7 +133,33 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --method is beam_search and modified_beam_search", + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", ) parser.add_argument( @@ -244,9 +286,9 @@ def main(): decoding_graph=decoding_graph, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - beam=8.0, - max_contexts=32, - max_states=8, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -254,6 +296,7 @@ def main(): hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) @@ -263,6 +306,7 @@ def main(): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index d0fe5d24e..7efa592f9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -19,20 +19,38 @@ Usage: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ - --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav -(1) beam search +(2) beam search ./pruned_transducer_stateless3/pretrained.py \ - --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav \ + --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav You can also use `./pruned_transducer_stateless3/exp/epoch-xx.pt`. @@ -79,9 +97,7 @@ def get_parser(): parser.add_argument( "--bpe-model", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="""Path to bpe.model.""", ) parser.add_argument( @@ -117,7 +133,33 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --method is beam_search and modified_beam_search", + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", ) parser.add_argument( @@ -244,9 +286,9 @@ def main(): decoding_graph=decoding_graph, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - beam=8.0, - max_contexts=32, - max_states=8, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -254,6 +296,7 @@ def main(): hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) @@ -263,6 +306,7 @@ def main(): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 4966ea57f..037f99bc7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -968,6 +968,7 @@ def run(rank, world_size, args): train_giga_cuts = gigaspeech.train_S_cuts() train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: cuts_musan = load_manifest(