diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
new file mode 100755
index 000000000..3617bc369
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
+popd
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless3/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless3/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless3/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless3/exp
+ done
+
+ rm pruned_transducer_stateless3/exp/*.pt
+fi
diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml
index e3fe3b904..6c8188b48 100644
--- a/.github/workflows/run-librispeech-2022-04-29.yml
+++ b/.github/workflows/run-librispeech-2022-04-29.yml
@@ -142,8 +142,8 @@ jobs:
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Display decoding results for pruned_transducer_stateless3
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
@@ -161,8 +161,8 @@ jobs:
find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for pruned_transducer_stateless2
uses: actions/upload-artifact@v2
diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
new file mode 100644
index 000000000..512f1b334
--- /dev/null
+++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
@@ -0,0 +1,151 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-pruned-transducer-stateless3-2022-05-13
+# stateless pruned transducer (reworked model) + giga speech
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_pruned_transducer_stateless3_2022_05_13:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+
+ - name: Display decoding results for pruned_transducer_stateless3
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR
+ tree pruned_transducer_stateless3/exp
+ cd pruned_transducer_stateless3/exp
+ echo "===greedy search==="
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for pruned_transducer_stateless3
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+        name: python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-05-13
+ path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
diff --git a/README.md b/README.md
index c4dad6aaf..47bf0e212 100644
--- a/README.md
+++ b/README.md
@@ -107,7 +107,7 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
| | test-clean | test-other |
|-----|------------|------------|
-| WER | 2.19 | 4.97 |
+| WER | 2.00 | 4.63 |
### Aishell
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 3143fa077..874700f11 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,6 +1,6 @@
## Results
-### LibriSpeech BPE training results (Pruned Transducer 3)
+### LibriSpeech BPE training results (Pruned Transducer 3, 2022-04-29)
[pruned_transducer_stateless3](./pruned_transducer_stateless3)
Same as `Pruned Transducer 2` but using the XL subset from
@@ -152,6 +152,67 @@ for epoch in 27; do
done
```
+### LibriSpeech BPE training results (Pruned Transducer 3, 2022-05-13)
+
+Same setup as [pruned_transducer_stateless3](./pruned_transducer_stateless3) (2022-04-29)
+but change `--giga-prob` from 0.8 to 0.9. Also use `repeat` on gigaspeech XL
+subset so that the gigaspeech dataloader never exhausts.
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|---------------------------------------------|
+| greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 |
+| modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 |
+| fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./prepare.sh
+./prepare_giga_speech.sh
+
+./pruned_transducer_stateless3/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --full-libri 1 \
+ --exp-dir pruned_transducer_stateless3/exp-0.9 \
+ --max-duration 300 \
+ --use-fp16 1 \
+ --lr-epochs 4 \
+ --num-workers 2 \
+ --giga-prob 0.9
+```
+
+The tensorboard log is available at
+<!-- TODO(review): tensorboard link appears to be missing here — confirm and add the tensorboard.dev URL -->
+
+Decoding commands:
+
+```bash
+for iter in 1224000; do
+ for avg in 14; do
+ for method in greedy_search modified_beam_search fast_beam_search ; do
+ ./pruned_transducer_stateless3/decode.py \
+ --iter $iter \
+ --avg $avg \
+ --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+ --max-duration 600 \
+ --decoding-method $method \
+ --max-sym-per-frame 1 \
+ --beam 4 \
+ --max-contexts 32
+ done
+ done
+done
+```
+
+The pretrained models, training logs, decoding logs, and decoding results
+can be found at
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13>
+
+
### LibriSpeech BPE training results (Pruned Transducer 2)
[pruned_transducer_stateless2](./pruned_transducer_stateless2)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
index 148bf7b02..d21a737b8 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
@@ -19,38 +19,38 @@ Usage:
(1) greedy search
./pruned_transducer_stateless/pretrained.py \
- --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method greedy_search \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
(2) beam search
./pruned_transducer_stateless/pretrained.py \
- --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method beam_search \
- --beam-size 4 \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless/pretrained.py \
- --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method modified_beam_search \
- --beam-size 4 \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless/pretrained.py \
- --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method fast_beam_search \
- --beam-size 4 \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`.
@@ -233,6 +233,9 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
index bcafe68d6..21bcf7cfd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -19,20 +19,38 @@ Usage:
(1) greedy search
./pruned_transducer_stateless2/pretrained.py \
- --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method greedy_search \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
-(1) beam search
+(2) beam search
./pruned_transducer_stateless2/pretrained.py \
- --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method beam_search \
- --beam-size 4 \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
@@ -79,9 +97,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
- help="""Path to bpe.model.
- Used only when method is ctc-decoding.
- """,
+ help="""Path to bpe.model.""",
)
parser.add_argument(
@@ -117,7 +133,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
- help="Used only when --method is beam_search and modified_beam_search",
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@@ -244,9 +286,9 @@ def main():
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
- beam=8.0,
- max_contexts=32,
- max_states=8,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -254,6 +296,7 @@ def main():
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
@@ -263,6 +306,7 @@ def main():
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
index d0fe5d24e..7efa592f9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
@@ -19,20 +19,38 @@ Usage:
(1) greedy search
./pruned_transducer_stateless3/pretrained.py \
- --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method greedy_search \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
-(1) beam search
+(2) beam search
./pruned_transducer_stateless3/pretrained.py \
- --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method beam_search \
- --beam-size 4 \
- /path/to/foo.wav \
- /path/to/bar.wav \
+ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless3/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless3/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
You can also use `./pruned_transducer_stateless3/exp/epoch-xx.pt`.
@@ -79,9 +97,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
- help="""Path to bpe.model.
- Used only when method is ctc-decoding.
- """,
+ help="""Path to bpe.model.""",
)
parser.add_argument(
@@ -117,7 +133,33 @@ def get_parser():
"--beam-size",
type=int,
default=4,
- help="Used only when --method is beam_search and modified_beam_search",
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
@@ -244,9 +286,9 @@ def main():
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
- beam=8.0,
- max_contexts=32,
- max_states=8,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -254,6 +296,7 @@ def main():
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
@@ -263,6 +306,7 @@ def main():
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index 4966ea57f..037f99bc7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -968,6 +968,7 @@ def run(rank, world_size, args):
train_giga_cuts = gigaspeech.train_S_cuts()
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
+ train_giga_cuts = train_giga_cuts.repeat(times=None)
if args.enable_musan:
cuts_musan = load_manifest(