diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml new file mode 100644 index 000000000..7af2299a4 --- /dev/null +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -0,0 +1,108 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-pre-trained-tranducer-stateless + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +jobs: + run_pre_trained_transducer_stateless: + if: github.event.label.name == 'ready' || github.event_name == 'push' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.7, 3.8, 3.9] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + k2-version: ["1.9.dev20211101"] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + run: | + python3 -m pip install --upgrade pip pytest + # numpy 1.20.x does not support python 3.6 + pip install numpy==1.19 + pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + + python3 -m pip install git+https://github.com/lhotse-speech/lhotse + python3 -m pip install kaldifeat + # We are in ./icefall and there is a file: requirements.txt in it + pip install -r requirements.txt + + - name: Install graphviz + shell: bash + run: | + python3 -m pip install -qq graphviz + sudo apt-get -qq install graphviz + + - name: Download pre-trained model + shell: bash + run: | + sudo apt-get -qq install git-lfs tree sox + cd egs/librispeech/ASR + mkdir tmp + cd tmp + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22 + cd .. + tree tmp + soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav + ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav + + - name: Run greedy search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav + + - name: Run beam search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/librispeech/ASR + ./transducer_stateless/pretrained.py \ + --method beam_search \ + --beam-size 4 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained.yml index 710ca2603..1758a3521 100644 --- a/.github/workflows/run-pretrained.yml +++ b/.github/workflows/run-pretrained.yml @@ -30,7 +30,7 @@ jobs: strategy: matrix: os: [ubuntu-18.04] - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] torch: ["1.10.0"] torchaudio: ["0.10.0"] k2-version: ["1.9.dev20211101"] diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index baa9c1727..f2c63a3b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: # os: [ubuntu-18.04, macos-10.15] # disable macOS test for now. os: [ubuntu-18.04] - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8] torch: ["1.8.0", "1.10.0"] torchaudio: ["0.8.0", "0.10.0"] k2-version: ["1.9.dev20211101"] @@ -106,6 +106,12 @@ jobs: if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then cd ../transducer pytest -v -s + + cd ../transducer_stateless + pytest -v -s + + cd ../transducer_lstm + pytest -v -s fi - name: Run tests @@ -125,4 +131,10 @@ jobs: if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then cd ../transducer pytest -v -s + + cd ../transducer_stateless + pytest -v -s + + cd ../transducer_lstm + pytest -v -s fi diff --git a/README.md b/README.md index 23389d483..931fb0198 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,12 @@ We do provide a Colab notebook for this recipe. ### LibriSpeech -We provide 3 models for this recipe: +We provide 4 models for this recipe: - [conformer CTC model][LibriSpeech_conformer_ctc] - [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc] -- [RNN-T Conformer model][LibriSpeech_transducer] +- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer] +- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless] #### Conformer CTC Model @@ -62,9 +63,9 @@ The WER for this model is: We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) -#### RNN-T Conformer model +#### Transducer: Conformer encoder + LSTM decoder -Using Conformer as encoder. +Using Conformer as encoder and LSTM as decoder. The best WER with greedy search is: @@ -74,6 +75,21 @@ The best WER with greedy search is: We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) +#### Transducer: Conformer encoder + Embedding decoder + +Using Conformer as encoder. The decoder consists of 1 embedding layer +and 1 convolutional layer. + +The best WER using beam search with beam size 4 is: + +| | test-clean | test-other | +|-----|------------|------------| +| WER | 2.92 | 7.37 | + +Note: No auxiliary losses are used in the training and no LMs are used +in the decoding. + +We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing) ### Aishell @@ -143,6 +159,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc [LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc [LibriSpeech_transducer]: egs/librispeech/ASR/transducer +[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless [Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc [Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc [TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index ae0c2684d..c8ee98d7d 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -1,3 +1,20 @@ +# Introduction + Please refer to for how to run models in this recipe. + +# Transducers + +There are various folders containing the name `transducer` in this folder. +The following table lists the differences among them. + +| | Encoder | Decoder | +|------------------------|-----------|--------------------| +| `transducer` | Conformer | LSTM | +| `transducer_stateless` | Conformer | Embedding + Conv1d | +| `transducer_lstm ` | LSTM | LSTM | + +The decoder in `transducer_stateless` is modified from the paper +[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 1f5edb571..317b1591a 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,10 +1,69 @@ ## Results -### LibriSpeech BPE training results (RNN-T) +### LibriSpeech BPE training results (Transducer) + +#### 2021-12-22 +Conformer encoder + non-current decoder. The decoder +contains only an embedding layer and a Conv1d (with kernel size 2). + +The WERs are + +| | test-clean | test-other | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 2.99 | 7.52 | --epoch 20, --avg 10, --max-duration 100 | +| beam search (beam size 2) | 2.95 | 7.43 | | +| beam search (beam size 3) | 2.94 | 7.37 | | +| beam search (beam size 4) | 2.92 | 7.37 | | +| beam search (beam size 5) | 2.93 | 7.38 | | +| beam search (beam size 8) | 2.92 | 7.38 | | + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir transducer_stateless/exp-full \ + --full-libri 1 \ + --max-duration 250 \ + --lr-factor 3 +``` + +The tensorboard training log can be found at + + +The decoding command is: +``` +epoch=20 +avg=10 + +## greedy search +./transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir transducer_stateless/exp-full \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 100 + +## beam search +./transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir transducer_stateless/exp-full \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 +``` + #### 2021-12-17 +Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`. -RNN-T + Conformer encoder +Conformer encoder + LSTM decoder. The best WER is @@ -12,7 +71,7 @@ The best WER is |-----|------------|------------| | WER | 3.16 | 7.71 | -using `--epoch 26 --avg 12` during decoding with greedy search. +using `--epoch 26 --avg 12` with **greedy search**. The training command to reproduce the above WER is: diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index 8a4d3ca69..cb9afd8a2 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface from icefall.utils import add_sos -assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" -) - class Transducer(nn.Module): """It implements https://arxiv.org/pdf/1211.3711.pdf @@ -115,6 +110,11 @@ class Transducer(nn.Module): # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) + assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" + ) + loss = torchaudio.functional.rnnt_loss( logits=logits, targets=y_padded, diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index 8a4d3ca69..cb9afd8a2 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface from icefall.utils import add_sos -assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" -) - class Transducer(nn.Module): """It implements https://arxiv.org/pdf/1211.3711.pdf @@ -115,6 +110,11 @@ class Transducer(nn.Module): # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) + assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" + ) + loss = torchaudio.functional.rnnt_loss( logits=logits, targets=y_padded, diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 88f23e922..45118a8bc 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -15,8 +15,9 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional +import numpy as np import torch from model import Transducer @@ -35,21 +36,35 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out = model.decoder(sos) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + T = encoder_out.size(1) t = 0 - hyp = [] - - sym_per_frame = 0 - sym_per_utt = 0 + hyp = [blank_id] * context_size + # Maximum symbols per utterance. max_sym_per_utt = 1000 + + # If at frame t, it decodes more than this number of symbols, + # it will move to the next step t+1 max_sym_per_frame = 3 + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + while t < T and sym_per_utt < max_sym_per_utt: # fmt: off current_encoder_out = encoder_out[:, t:t+1, :] @@ -57,14 +72,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: logits = model.joiner(current_encoder_out, decoder_out) # logits is (1, 1, 1, vocab_size) - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - # TODO: Use logits.argmax() - y = log_prob.argmax() + y = logits.argmax().item() if y != blank_id: - hyp.append(y.item()) - y = y.reshape(1, 1) - decoder_out = model.decoder(y) + hyp.append(y) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) sym_per_utt += 1 sym_per_frame += 1 @@ -72,24 +87,135 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: if y == blank_id or sym_per_frame > max_sym_per_frame: sym_per_frame = 0 t += 1 + hyp = hyp[context_size:] # remove blanks return hyp @dataclass class Hypothesis: - ys: List[int] # the predicted sequences so far - log_prob: float # The log prob of ys + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] - # Optional decoder state. We assume it is LSTM for now, - # so the state is a tuple (h, c) - decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + # The log prob of ys + log_prob: float + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None): + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self): + return self._data + + # def add(self, ys: List[int], log_prob: float): + def add(self, hyp: Hypothesis): + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] + old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + + """ + if length_norm: + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: float) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + that have `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for key, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) def beam_search( model: Transducer, encoder_out: torch.Tensor, - beam: int = 5, + beam: int = 4, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -111,110 +237,98 @@ def beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id - sos_id = model.decoder.sos_id + context_size = model.decoder.context_size + device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out, (h, c) = model.decoder(sos) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + T = encoder_out.size(1) t = 0 - B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)] - max_u = 20000 # terminate after this number of steps - u = 0 - cache: Dict[ - str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ] = {} + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) - while t < T and u < max_u: + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: # fmt: off current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on A = B - B = [] - # for hyp in A: - # for h in A: - # if h.ys == hyp.ys[:-1]: - # # update the score of hyp - # decoder_input = torch.tensor( - # [h.ys[-1]], device=device - # ).reshape(1, 1) - # decoder_out, _ = model.decoder( - # decoder_input, h.decoder_state - # ) - # logits = model.joiner(current_encoder_out, decoder_out) - # log_prob = logits.log_softmax(dim=-1) - # log_prob = log_prob.squeeze() - # hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item() + B = HypothesisList() - while u < max_u: - y_star = max(A, key=lambda hyp: hyp.log_prob) + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() A.remove(y_star) - # Note: y_star.ys is unhashable, i.e., cannot be used - # as a key into a dict - cached_key = "_".join(map(str, y_star.ys)) + cached_key = y_star.key - if cached_key not in cache: + if cached_key not in decoder_cache: decoder_input = torch.tensor( - [y_star.ys[-1]], device=device - ).reshape(1, 1) + [y_star.ys[-context_size:]], device=device + ).reshape(1, context_size) - decoder_out, decoder_state = model.decoder( - decoder_input, - y_star.decoder_state, - ) - cache[cached_key] = (decoder_out, decoder_state) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_cache[cached_key] = decoder_out else: - decoder_out, decoder_state = cache[cached_key] + decoder_out = decoder_cache[cached_key] - logits = model.joiner(current_encoder_out, decoder_out) - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner(current_encoder_out, decoder_out) - # If we choose blank here, add the new hypothesis to B. - # Otherwise, add the new hypothesis to A + # TODO(fangjun): Ccale the blank posterior - # First, choose blank + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol skip_log_prob = log_prob[blank_id] new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() # ys[:] returns a copy of ys - new_y_star = Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - # Caution: Use y_star.decoder_state here - decoder_state=y_star.decoder_state, - ) - B.append(new_y_star) + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) - # Second, choose other labels - for i, v in enumerate(log_prob.tolist()): - if i in (blank_id, sos_id): + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i == blank_id: continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - decoder_state=decoder_state, - ) - A.append(new_hyp) - u += 1 - # check whether B contains more than "beam" elements more probable + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + + # Check whether B contains more than "beam" elements more probable # than the most probable in A - A_most_probable = max(A, key=lambda hyp: hyp.log_prob) - B = sorted( - [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob], - key=lambda hyp: hyp.log_prob, - reverse=True, - ) - if len(B) >= beam: - B = B[:beam] + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) break + t += 1 - best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:])) - ys = best_hyp.ys[1:] # [1:] to remove the blank + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks return ys diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 2fa5cc55e..82175e8db 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -24,15 +24,15 @@ Usage: --exp-dir ./transducer_stateless/exp \ --max-duration 100 \ --decoding-method greedy_search -(2) beam search +(2) beam search ./transducer_stateless/decode.py \ --epoch 14 \ --avg 7 \ --exp-dir ./transducer_stateless/exp \ --max-duration 100 \ --decoding-method beam_search \ - --beam-size 8 + --beam-size 4 """ @@ -70,14 +70,14 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=77, + default=20, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=55, + default=10, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--beam-size", type=int, - default=5, + default=4, help="Used only when --decoding-method is beam_search", ) @@ -130,7 +130,8 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, "vgg_frontend": False, "use_feat_batchnorm": True, - # decoder params + # parameters for decoder + "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -158,6 +159,7 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, + context_size=params.context_size, ) return decoder @@ -392,9 +394,8 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index 9d6b3aaf2..cedbc937e 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F class Decoder(nn.Module): @@ -35,6 +36,7 @@ class Decoder(nn.Module): vocab_size: int, embedding_dim: int, blank_id: int, + context_size: int, ): """ Args: @@ -44,6 +46,9 @@ class Decoder(nn.Module): Dimension of the input embedding. blank_id: The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() self.embedding = nn.Embedding( @@ -53,13 +58,40 @@ class Decoder(nn.Module): ) self.blank_id = blank_id - def forward(self, y: torch.Tensor) -> torch.Tensor: + assert context_size >= 1, context_size + self.context_size = context_size + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=embedding_dim, + out_channels=embedding_dim, + kernel_size=context_size, + padding=0, + groups=embedding_dim, + bias=False, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ Args: y: A 2-D tensor of shape (N, U) with blank prepended. + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. Returns: Return a tensor of shape (N, U, embedding_dim). """ embeding_out = self.embedding(y) + if self.context_size > 1: + embeding_out = embeding_out.permute(0, 2, 1) + if need_pad is True: + embeding_out = F.pad( + embeding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embeding_out.size(-1) == self.context_size + embeding_out = self.conv(embeding_out) + embeding_out = embeding_out.permute(0, 2, 1) return embeding_out diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py new file mode 100755 index 000000000..a877b5067 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./transducer_stateless/export.py \ + --exp-dir ./transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `transducer_stateless/decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./transducer_stateless/decode.py \ + --exp-dir ./transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 1 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from model import Transducer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + "use_feat_batchnorm": True, + # parameters for decoder + "context_size": 2, # tri-gram + "env_info": get_env_info(), + } + ) + return params + + +def get_encoder_model(params: AttributeDict): + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + return encoder + + +def get_decoder_model(params: AttributeDict): + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.encoder_out_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict): + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict): + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 7053f621e..2f0f9a183 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface from icefall.utils import add_sos -assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" -) - class Transducer(nn.Module): """It implements https://arxiv.org/pdf/1211.3711.pdf @@ -113,6 +108,11 @@ class Transducer(nn.Module): # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) + assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" + ) + loss = torchaudio.functional.rnnt_loss( logits=logits, targets=y_padded, diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py new file mode 100755 index 000000000..49efa6749 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) greedy search +./transducer_stateless/pretrained.py \ + --checkpoint ./transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav \ + +(1) beam search +./transducer_stateless/pretrained.py \ + --checkpoint ./transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + +You can also use `./transducer_stateless/exp/epoch-xx.pt`. + +Note: ./transducer_stateless/exp/pretrained.pt is generated by +./transducer_stateless/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import beam_search, greedy_search +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from model import Transducer +from torch.nn.utils.rnn import pad_sequence + +from icefall.env import get_env_info +from icefall.utils import AttributeDict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="Used only when --method is beam_search", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + "use_feat_batchnorm": True, + # parameters for decoder + "context_size": 2, # tri-gram + "env_info": get_env_info(), + } + ) + return params + + +def get_encoder_model(params: AttributeDict): + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + return encoder + + +def get_decoder_model(params: AttributeDict): + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.encoder_out_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict): + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict): + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + with torch.no_grad(): + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search(model=model, encoder_out=encoder_out_i) + elif params.method == "beam_search": + hyp = beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer_stateless/test_decoder.py b/egs/librispeech/ASR/transducer_stateless/test_decoder.py new file mode 100755 index 000000000..3a653c1b7 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/test_decoder.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_stateless/test_decoder.py +""" + +import torch +from decoder import Decoder + + +def test_decoder(): + vocab_size = 3 + blank_id = 0 + embedding_dim = 128 + context_size = 4 + + decoder = Decoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + blank_id=blank_id, + context_size=context_size, + ) + N = 100 + U = 20 + x = torch.randint(low=0, high=vocab_size, size=(N, U)) + y = decoder(x) + assert y.shape == (N, U, embedding_dim) + + # for inference + x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) + y = decoder(x, need_pad=False) + assert y.shape == (N, 1, embedding_dim) + + +def main(): + test_decoder() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index e20aedf9b..a2bf4700c 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -92,7 +92,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=78, + default=30, help="Number of epochs to train.", ) @@ -202,6 +202,8 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, "vgg_frontend": False, "use_feat_batchnorm": True, + # parameters for decoder + "context_size": 2, # tri-gram # parameters for Noam "weight_decay": 1e-6, "warm_step": 80000, # For the 100h subset, use 8k @@ -233,6 +235,7 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, + context_size=params.context_size, ) return decoder