diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
new file mode 100644
index 000000000..efea5366b
--- /dev/null
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
@@ -0,0 +1,152 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-100h
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+jobs:
+ run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
+ if: github.event.label.name == 'ready' || github.event_name == 'push'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+ torch: ["1.10.0"]
+ torchaudio: ["0.10.0"]
+ k2-version: ["1.9.dev20211101"]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v1
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install Python dependencies
+ run: |
+ python3 -m pip install --upgrade pip pytest
+ # numpy 1.20.x does not support python 3.6
+ pip install numpy==1.19
+ pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+ pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+
+ python3 -m pip install git+https://github.com/lhotse-speech/lhotse
+ python3 -m pip install kaldifeat
+ # We are in ./icefall and there is a file: requirements.txt in it
+ pip install -r requirements.txt
+
+ - name: Install graphviz
+ shell: bash
+ run: |
+ python3 -m pip install -qq graphviz
+ sudo apt-get -qq install graphviz
+
+ - name: Download pre-trained model
+ shell: bash
+ run: |
+ sudo apt-get -qq install git-lfs tree sox
+ cd egs/librispeech/ASR
+ mkdir tmp
+ cd tmp
+ git lfs install
+ git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21
+
+ cd ..
+ tree tmp
+ soxi tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav
+ ls -lh tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav
+
+ - name: Run greedy search decoding (max-sym-per-frame 1)
+ shell: bash
+ run: |
+ export PYTHONPATH=$PWD:PYTHONPATH
+ cd egs/librispeech/ASR
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame 1 \
+ --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
+ --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
+
+ - name: Run greedy search decoding (max-sym-per-frame 2)
+ shell: bash
+ run: |
+ export PYTHONPATH=$PWD:PYTHONPATH
+ cd egs/librispeech/ASR
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame 2 \
+ --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
+ --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
+
+ - name: Run greedy search decoding (max-sym-per-frame 3)
+ shell: bash
+ run: |
+ export PYTHONPATH=$PWD:PYTHONPATH
+ cd egs/librispeech/ASR
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame 3 \
+ --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
+ --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
+
+ - name: Run beam search decoding
+ shell: bash
+ run: |
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ cd egs/librispeech/ASR
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method beam_search \
+ --beam-size 4 \
+ --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
+ --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
+
+ - name: Run modified beam search decoding
+ shell: bash
+ run: |
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ cd egs/librispeech/ASR
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method modified_beam_search \
+ --beam-size 4 \
+ --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \
+ --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \
+ ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index c8ee98d7d..211a7d120 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -9,11 +9,12 @@ for how to run models in this recipe.
There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.
-| | Encoder | Decoder |
-|------------------------|-----------|--------------------|
-| `transducer` | Conformer | LSTM |
-| `transducer_stateless` | Conformer | Embedding + Conv1d |
-| `transducer_lstm ` | LSTM | LSTM |
+| | Encoder | Decoder | Comment |
+|---------------------------------------|-----------|--------------------|---------------------------------------------------|
+| `transducer` | Conformer | LSTM | |
+| `transducer_stateless` | Conformer | Embedding + Conv1d | |
+| `transducer_lstm` | LSTM | LSTM | |
+| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/RESULTS-100hours.md b/egs/librispeech/ASR/RESULTS-100hours.md
new file mode 100644
index 000000000..40245c917
--- /dev/null
+++ b/egs/librispeech/ASR/RESULTS-100hours.md
@@ -0,0 +1,75 @@
+# Results for train-clean-100
+
+This page shows the WERs for test-clean/test-other using only
+train-clean-100 subset as training data.
+
+## Conformer encoder + embedding decoder
+
+### 2022-02-21
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|------------------------------------------|
+| greedy search (max sym per frame 1) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
+| greedy search (max sym per frame 2) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
+| greedy search (max sym per frame 3) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
+| modified beam search (beam size 4) | 6.31 | 16.3 | --epoch 57, --avg 17, --max-duration 100 |
+
+
+The training command for reproducing is given below:
+
+```bash
+cd egs/librispeech/ASR/
+./prepare.sh
+./prepare_giga_speech.sh
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./transducer_stateless_multi_datasets/train.py \
+ --world-size 2 \
+ --num-epochs 60 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_multi_datasets/exp-100-2 \
+ --full-libri 0 \
+ --max-duration 300 \
+ --lr-factor 1 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --modified-transducer-prob 0.25
+ --giga-prob 0.2
+```
+
+The decoding command is given below:
+
+```bash
+for epoch in 57; do
+ for avg in 17; do
+ for sym in 1 2 3; do
+ ./transducer_stateless_multi_datasets/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_multi_datasets/exp-100-2 \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 100 \
+ --context-size 2 \
+ --max-sym-per-frame $sym
+ done
+ done
+done
+
+epoch=57
+avg=17
+./transducer_stateless_multi_datasets/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_multi_datasets/exp-100-2 \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+```
+
+The tensorboard log is available at
+
+
+A pre-trained model and decoding logs can be found at
+
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 4082c3e97..136afe9c0 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -191,15 +191,10 @@ def get_transducer_model(params: AttributeDict):
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
- decoder_giga = get_decoder_model(params)
- joiner_giga = get_joiner_model(params)
-
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
- decoder_giga=decoder_giga,
- joiner_giga=joiner_giga,
)
return model
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index 5687260df..7d14d011d 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -20,22 +20,23 @@
# to a single one using model averaging.
"""
Usage:
-./transducer_stateless/export.py \
- --exp-dir ./transducer_stateless/exp \
+./transducer_stateless_multi_datasets/export.py \
+ --exp-dir ./transducer_stateless_multi_datasets/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
-To use the generated file with `transducer_stateless/decode.py`, you can do:
+To use the generated file with `transducer_stateless_multi_datasets/decode.py`,
+you can do::
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
- ./transducer_stateless/decode.py \
- --exp-dir ./transducer_stateless/exp \
+ ./transducer_stateless_multi_datasets/decode.py \
+ --exp-dir ./transducer_stateless_multi_datasets/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
@@ -84,7 +85,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
- default="transducer_stateless/exp",
+ default="transducer_stateless_multi_datasets/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
@@ -218,7 +219,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
- model.load_state_dict(average_checkpoints(filenames, device=device))
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
model.eval()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py
index ecffcf9ff..00b7c8334 100644
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py
@@ -49,7 +49,7 @@ class LibriSpeech:
return load_manifest(f)
def train_other_500_cuts(self) -> CutSet:
- f = self.args.manifest_dir / "cuts_train-other-500.json.gz"
+ f = self.manifest_dir / "cuts_train-other-500.json.gz"
logging.info(f"About to get train-other-500 cuts from {f}")
return load_manifest(f)
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py
index 919c19a86..8141f9a83 100644
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/model.py
@@ -15,6 +15,7 @@
# limitations under the License.
import random
+from typing import Optional
import k2
import torch
@@ -34,8 +35,8 @@ class Transducer(nn.Module):
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
- decoder_giga: nn.Module,
- joiner_giga: nn.Module,
+ decoder_giga: Optional[nn.Module] = None,
+ joiner_giga: Optional[nn.Module] = None,
):
"""
Args:
@@ -60,7 +61,9 @@ class Transducer(nn.Module):
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
- assert hasattr(decoder_giga, "blank_id")
+
+ if decoder_giga is not None:
+ assert hasattr(decoder_giga, "blank_id")
self.encoder = encoder
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 51b9d19da..720151ea0 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -738,8 +738,13 @@ def run(rank, world_size, args):
# XS 10 hours
# DEV 12 hours
# Test 40 hours
- # train_giga_cuts = gigaspeech.train_M_cuts()
- train_giga_cuts = gigaspeech.train_S_cuts()
+ if params.full_libri:
+ logging.info("Using the L subset of GigaSpeech (2.5k hours)")
+ train_giga_cuts = gigaspeech.train_L_cuts()
+ else:
+ logging.info("Using the S subset of GigaSpeech (250 hours)")
+ train_giga_cuts = gigaspeech.train_S_cuts()
+
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
if args.enable_musan:
@@ -868,7 +873,7 @@ def main():
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
- assert 0 < args.giga_prob < 1, args.giga_prob
+ assert 0 <= args.giga_prob < 1, args.giga_prob
world_size = args.world_size
assert world_size >= 1