From fb6a57e9e01dd8aae2af2a6b4568daad8bc8ab32 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Thu, 23 Dec 2021 07:55:02 +0800
Subject: [PATCH 1/2] Increase the size of the context in the RNN-T decoder. (#153)

---
 .../run-pretrained-transducer-stateless.yml   | 108 ++++++
 .github/workflows/run-pretrained.yml          |   2 +-
 .github/workflows/test.yml                    |  14 +-
 README.md                                     |  25 +-
 egs/librispeech/ASR/README.md                 |  17 +
 egs/librispeech/ASR/RESULTS.md                |  65 +++-
 egs/librispeech/ASR/transducer/model.py       |  10 +-
 egs/librispeech/ASR/transducer_lstm/model.py  |  10 +-
 .../ASR/transducer_stateless/beam_search.py   | 304 +++++++++++------
 .../ASR/transducer_stateless/decode.py        |  17 +-
 .../ASR/transducer_stateless/decoder.py       |  34 +-
 .../ASR/transducer_stateless/export.py        | 244 ++++++++++++++
 .../ASR/transducer_stateless/model.py         |  10 +-
 .../ASR/transducer_stateless/pretrained.py    | 307 ++++++++++++++++++
 .../ASR/transducer_stateless/test_decoder.py  |  58 ++++
 .../ASR/transducer_stateless/train.py         |   5 +-
 16 files changed, 1101 insertions(+), 129 deletions(-)
 create mode 100644 .github/workflows/run-pretrained-transducer-stateless.yml
 create mode 100755 egs/librispeech/ASR/transducer_stateless/export.py
 create mode 100755 egs/librispeech/ASR/transducer_stateless/pretrained.py
 create mode 100755 egs/librispeech/ASR/transducer_stateless/test_decoder.py

diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
new file mode 100644
index 000000000..7af2299a4
--- /dev/null
+++ b/.github/workflows/run-pretrained-transducer-stateless.yml
@@ -0,0 +1,108 @@
+# Copyright      2021  Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
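+
+# This workflow downloads a pre-trained transducer-stateless model from
+# Hugging Face and decodes a few test wavs with it, using both greedy
+# search and beam search.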
+
+name: run-pre-trained-tranducer-stateless
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+jobs:
+  run_pre_trained_transducer_stateless:
+    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-18.04]
+        python-version: [3.7, 3.8, 3.9]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        k2-version: ["1.9.dev20211101"]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Python dependencies
+        run: |
+          python3 -m pip install --upgrade pip pytest
+          # numpy 1.20.x does not support python 3.6
+          pip install numpy==1.19
+          pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+
+          python3 -m pip install git+https://github.com/lhotse-speech/lhotse
+          python3 -m pip install kaldifeat
+          # We are in ./icefall and there is a file: requirements.txt in it
+          pip install -r requirements.txt
+
+      - name: Install graphviz
+        shell: bash
+        run: |
+          python3 -m pip install -qq graphviz
+          sudo apt-get -qq install graphviz
+
+      - name: Download pre-trained model
+        shell: bash
+        run: |
+          sudo apt-get -qq install git-lfs tree sox
+          cd egs/librispeech/ASR
+          mkdir tmp
+          cd tmp
+          git lfs install
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22
+          cd ..
+          tree tmp
+          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
+          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
+
+      - name: Run greedy search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer_stateless/pretrained.py \
+            --method greedy_search \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
+
+      - name: Run beam search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer_stateless/pretrained.py \
+            --method beam_search \
+            --beam-size 4 \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained.yml
index 710ca2603..1758a3521 100644
--- a/.github/workflows/run-pretrained.yml
+++ b/.github/workflows/run-pretrained.yml
@@ 
-30,7 +30,7 @@ jobs: strategy: matrix: os: [ubuntu-18.04] - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] torch: ["1.10.0"] torchaudio: ["0.10.0"] k2-version: ["1.9.dev20211101"] diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index baa9c1727..f2c63a3b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: # os: [ubuntu-18.04, macos-10.15] # disable macOS test for now. os: [ubuntu-18.04] - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8] torch: ["1.8.0", "1.10.0"] torchaudio: ["0.8.0", "0.10.0"] k2-version: ["1.9.dev20211101"] @@ -106,6 +106,12 @@ jobs: if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then cd ../transducer pytest -v -s + + cd ../transducer_stateless + pytest -v -s + + cd ../transducer_lstm + pytest -v -s fi - name: Run tests @@ -125,4 +131,10 @@ jobs: if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then cd ../transducer pytest -v -s + + cd ../transducer_stateless + pytest -v -s + + cd ../transducer_lstm + pytest -v -s fi diff --git a/README.md b/README.md index 23389d483..931fb0198 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,12 @@ We do provide a Colab notebook for this recipe. ### LibriSpeech -We provide 3 models for this recipe: +We provide 4 models for this recipe: - [conformer CTC model][LibriSpeech_conformer_ctc] - [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc] -- [RNN-T Conformer model][LibriSpeech_transducer] +- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer] +- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless] #### Conformer CTC Model @@ -62,9 +63,9 @@ The WER for this model is: We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) -#### RNN-T Conformer model +#### Transducer: Conformer encoder + LSTM decoder -Using Conformer as encoder. +Using Conformer as encoder and LSTM as decoder. The best WER with greedy search is: @@ -74,6 +75,21 @@ The best WER with greedy search is: We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) +#### Transducer: Conformer encoder + Embedding decoder + +Using Conformer as encoder. The decoder consists of 1 embedding layer +and 1 convolutional layer. + +The best WER using beam search with beam size 4 is: + +| | test-clean | test-other | +|-----|------------|------------| +| WER | 2.92 | 7.37 | + +Note: No auxiliary losses are used in the training and no LMs are used +in the decoding. 
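+
+A minimal sketch of this decoder's forward pass (illustrative only; the
+actual code lives in egs/librispeech/ASR/transducer_stateless/decoder.py):
+
+```python
+# y: (N, U) token IDs. The Conv1d has kernel size 2, so each output
+# frame sees only the current token and the previous one.
+embedding_out = self.embedding(y).permute(0, 2, 1)         # (N, C, U)
+embedding_out = F.pad(embedding_out, pad=(1, 0))           # left-pad by 1
+embedding_out = self.conv(embedding_out).permute(0, 2, 1)  # (N, U, C)
+```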
+
+We provide a Colab notebook to run a pre-trained transducer model (Conformer encoder + stateless decoder): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing)
 
 ### Aishell
 
@@ -143,6 +159,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
 [LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
 [LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
 [LibriSpeech_transducer]: egs/librispeech/ASR/transducer
+[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless
 [Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
 [Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
 [TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index ae0c2684d..c8ee98d7d 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -1,3 +1,20 @@
+# Introduction
+
 Please refer to <https://icefall.readthedocs.io/en/latest/> for how to run models in this recipe.
+
+# Transducers
+
+There are various folders containing the name `transducer` in this folder.
+The following table lists the differences among them.
+
+|                        | Encoder   | Decoder            |
+|------------------------|-----------|--------------------|
+| `transducer`           | Conformer | LSTM               |
+| `transducer_stateless` | Conformer | Embedding + Conv1d |
+| `transducer_lstm`      | LSTM      | LSTM               |
+
+The decoder in `transducer_stateless` is modified from the paper
+[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 1f5edb571..317b1591a 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,10 +1,69 @@
 ## Results
 
-### LibriSpeech BPE training results (RNN-T)
+### LibriSpeech BPE training results (Transducer)
+
+#### 2021-12-22
+Conformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer and a Conv1d (with kernel size 2).
+
+The WERs are
+
+|                           | test-clean | test-other | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             | 2.99       | 7.52       | --epoch 20, --avg 10, --max-duration 100 |
+| beam search (beam size 2) | 2.95       | 7.43       |                                          |
+| beam search (beam size 3) | 2.94       | 7.37       |                                          |
+| beam search (beam size 4) | 2.92       | 7.37       |                                          |
+| beam search (beam size 5) | 2.93       | 7.38       |                                          |
+| beam search (beam size 8) | 2.92       | 7.38       |                                          |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir transducer_stateless/exp-full \
+  --full-libri 1 \
+  --max-duration 250 \
+  --lr-factor 3
+```
+
+The tensorboard training log can be found at
+
+
+The decoding command is:
+```
+epoch=20
+avg=10
+
+## greedy search
+./transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir transducer_stateless/exp-full \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 100
+
+## beam search
+./transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir transducer_stateless/exp-full \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 100 \
+  --decoding-method beam_search \
+  --beam-size 4
+```
+
 
 #### 2021-12-17
+Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
-RNN-T + Conformer encoder +Conformer encoder + LSTM decoder. The best WER is @@ -12,7 +71,7 @@ The best WER is |-----|------------|------------| | WER | 3.16 | 7.71 | -using `--epoch 26 --avg 12` during decoding with greedy search. +using `--epoch 26 --avg 12` with **greedy search**. The training command to reproduce the above WER is: diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index 8a4d3ca69..cb9afd8a2 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface from icefall.utils import add_sos -assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" -) - class Transducer(nn.Module): """It implements https://arxiv.org/pdf/1211.3711.pdf @@ -115,6 +110,11 @@ class Transducer(nn.Module): # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) + assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" + ) + loss = torchaudio.functional.rnnt_loss( logits=logits, targets=y_padded, diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index 8a4d3ca69..cb9afd8a2 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface from icefall.utils import add_sos -assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" -) - class Transducer(nn.Module): """It implements https://arxiv.org/pdf/1211.3711.pdf @@ -115,6 +110,11 @@ class Transducer(nn.Module): # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) + assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" + ) + loss = torchaudio.functional.rnnt_loss( logits=logits, targets=y_padded, diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 88f23e922..45118a8bc 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -15,8 +15,9 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional +import numpy as np import torch from model import Transducer @@ -35,21 +36,35 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) - decoder_out = model.decoder(sos) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + T = encoder_out.size(1) t = 0 - hyp = [] - - sym_per_frame = 0 - sym_per_utt = 0 + hyp = [blank_id] * context_size + # Maximum symbols per utterance. 
max_sym_per_utt = 1000
+
+    # If at frame t, it decodes more than this number of symbols,
+    # it will move to the next frame t+1
     max_sym_per_frame = 3
 
+    # symbols per frame
+    sym_per_frame = 0
+
+    # symbols per utterance decoded so far
+    sym_per_utt = 0
+
     while t < T and sym_per_utt < max_sym_per_utt:
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
@@ -57,14 +72,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
 
         logits = model.joiner(current_encoder_out, decoder_out)
         # logits is (1, 1, 1, vocab_size)
-        log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (1, 1, 1, vocab_size)
-        # TODO: Use logits.argmax()
-        y = log_prob.argmax()
+        y = logits.argmax().item()
         if y != blank_id:
-            hyp.append(y.item())
-            y = y.reshape(1, 1)
-            decoder_out = model.decoder(y)
+            hyp.append(y)
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
+
+            decoder_out = model.decoder(decoder_input, need_pad=False)
 
             sym_per_utt += 1
             sym_per_frame += 1
@@ -72,24 +87,135 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         if y == blank_id or sym_per_frame > max_sym_per_frame:
             sym_per_frame = 0
             t += 1
+    hyp = hyp[context_size:]  # remove blanks
 
     return hyp
 
 
 @dataclass
 class Hypothesis:
-    ys: List[int]  # the predicted sequences so far
-    log_prob: float  # The log prob of ys
+    # The predicted tokens so far.
+    # Newly predicted tokens are appended to `ys`.
+    ys: List[int]
 
-    # Optional decoder state. We assume it is LSTM for now,
-    # so the state is a tuple (h, c)
-    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    # The log prob of ys
+    log_prob: float
+
+    @property
+    def key(self) -> str:
+        """Return a string representation of self.ys"""
+        return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self):
+        return self._data
+
+    def add(self, hyp: Hypothesis):
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, its probability is updated using
+        `log-sum-exp` with the existing one.
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]
+            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+
+        """
+        if length_norm:
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
+        else:
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(self, threshold: float) -> "HypothesisList":
+        """Remove all Hypotheses whose log_prob is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          whose `log_prob` is greater than the given `threshold`.
+        """
+        ans = HypothesisList()
+        for key, hyp in self._data.items():
+            if hyp.log_prob > threshold:
+                ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(self, k: int) -> "HypothesisList":
+        """Return the top-k hypotheses."""
+        hyps = list(self._data.items())
+
+        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: str):
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = []
+        for key in self:
+            s.append(key)
+        return ", ".join(s)
 
 
 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
-    beam: int = 5,
+    beam: int = 4,
 ) -> List[int]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -111,110 +237,98 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
+    context_size = model.decoder.context_size
+
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out, (h, c) = model.decoder(sos)
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
     T = encoder_out.size(1)
     t = 0
-    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
-    max_u = 20000  # terminate after this number of steps
-    u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    B = HypothesisList()
+    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
 
-    while t < T and u < max_u:
+    max_sym_per_utt = 20000
+
+    sym_per_utt = 0
+
+    decoder_cache: Dict[str, torch.Tensor] = {}
+
+    while t < T and sym_per_utt < max_sym_per_utt:
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
         A = B
-        B = []
-        # for hyp in A:
-        #     for h in A:
-        #         if h.ys == hyp.ys[:-1]:
-        #             # update the score of hyp
-        #             decoder_input = torch.tensor(
-        #                 [h.ys[-1]], device=device
-        #             ).reshape(1, 1)
-        #             decoder_out, _ = model.decoder(
-        #                 decoder_input, h.decoder_state
-        #             )
-        #             logits = model.joiner(current_encoder_out, decoder_out)
-        #             log_prob = logits.log_softmax(dim=-1)
-        #             log_prob = log_prob.squeeze()
-        #             hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
+        B = HypothesisList()
 
-        while u < max_u:
-            y_star = max(A, key=lambda hyp: hyp.log_prob)
+        joint_cache: Dict[str, torch.Tensor] = {}
+
+        # TODO(fangjun): Implement prefix search to update the `log_prob`
+        # of hypotheses in A
+
+        while True:
+            y_star = A.get_most_probable()
             A.remove(y_star)
 
-            # Note: y_star.ys is unhashable, i.e., cannot be used
-            # as a key into a dict
-            cached_key = "_".join(map(str, y_star.ys))
+            cached_key = y_star.key
 
-            if cached_key not in cache:
+            if cached_key not in decoder_cache:
                 decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                    [y_star.ys[-context_size:]], device=device
+                ).reshape(1, context_size)
 
-                decoder_out, decoder_state = model.decoder(
-                    decoder_input,
-                    y_star.decoder_state,
-                )
-                cache[cached_key] = (decoder_out, decoder_state)
+                decoder_out = model.decoder(decoder_input, need_pad=False)
+                decoder_cache[cached_key] = decoder_out
             else:
-                decoder_out, decoder_state = cache[cached_key]
+                decoder_out = decoder_cache[cached_key]
 
-            logits = model.joiner(current_encoder_out, decoder_out)
-            log_prob = logits.log_softmax(dim=-1)
-            # log_prob is (1, 1, 1, vocab_size)
-            log_prob = log_prob.squeeze()
-            # Now log_prob is (vocab_size,)
+            cached_key += f"-t-{t}"
+            if cached_key not in joint_cache:
+                logits = model.joiner(current_encoder_out, decoder_out)
 
-            # If we choose blank here, add the new hypothesis to B.
-            # Otherwise, add the new hypothesis to A
+                # TODO(fangjun): Scale the blank posterior
 
-            # First, choose blank
+                log_prob = logits.log_softmax(dim=-1)
+                # log_prob is (1, 1, 1, vocab_size)
+                log_prob = log_prob.squeeze()
+                # Now log_prob is (vocab_size,)
+                joint_cache[cached_key] = log_prob
+            else:
+                log_prob = joint_cache[cached_key]
+
+            # First, process the blank symbol
             skip_log_prob = log_prob[blank_id]
             new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+
             # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
+            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
 
-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
+            # Second, process other non-blank labels
+            values, indices = log_prob.topk(beam + 1)
+            for i, v in zip(indices.tolist(), values.tolist()):
+                if i == blank_id:
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-            u += 1
-            # check whether B contains more than "beam" elements more probable
+                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+
+            # Check whether B contains more than "beam" elements more probable
             # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
+            A_most_probable = A.get_most_probable()
+
+            kept_B = B.filter(A_most_probable.log_prob)
+
+            if len(kept_B) >= beam:
+                B = kept_B.topk(beam)
                 break
+
         t += 1
 
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
+
+    best_hyp = B.get_most_probable(length_norm=True)
+    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
     return ys
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 2fa5cc55e..82175e8db 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -24,15 +24,15 @@ Usage:
         --exp-dir ./transducer_stateless/exp \
         --max-duration 100 \
         --decoding-method greedy_search
-(2) beam search 
+(2) beam search
 ./transducer_stateless/decode.py \
         --epoch 14 \
         --avg 7 \
         --exp-dir ./transducer_stateless/exp \
         --max-duration 100 \
         --decoding-method beam_search \
-        --beam-size 8
+        --beam-size 4
 """
 
 
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=20,
         help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=10,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
", @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--beam-size", type=int, - default=5, + default=4, help="Used only when --decoding-method is beam_search", ) @@ -130,7 +130,8 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, "vgg_frontend": False, "use_feat_batchnorm": True, - # decoder params + # parameters for decoder + "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -158,6 +159,7 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, + context_size=params.context_size, ) return decoder @@ -392,9 +394,8 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index 9d6b3aaf2..cedbc937e 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F class Decoder(nn.Module): @@ -35,6 +36,7 @@ class Decoder(nn.Module): vocab_size: int, embedding_dim: int, blank_id: int, + context_size: int, ): """ Args: @@ -44,6 +46,9 @@ class Decoder(nn.Module): Dimension of the input embedding. blank_id: The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() self.embedding = nn.Embedding( @@ -53,13 +58,40 @@ class Decoder(nn.Module): ) self.blank_id = blank_id - def forward(self, y: torch.Tensor) -> torch.Tensor: + assert context_size >= 1, context_size + self.context_size = context_size + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=embedding_dim, + out_channels=embedding_dim, + kernel_size=context_size, + padding=0, + groups=embedding_dim, + bias=False, + ) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ Args: y: A 2-D tensor of shape (N, U) with blank prepended. + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. Returns: Return a tensor of shape (N, U, embedding_dim). """ embeding_out = self.embedding(y) + if self.context_size > 1: + embeding_out = embeding_out.permute(0, 2, 1) + if need_pad is True: + embeding_out = F.pad( + embeding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embeding_out.size(-1) == self.context_size + embeding_out = self.conv(embeding_out) + embeding_out = embeding_out.permute(0, 2, 1) return embeding_out diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py new file mode 100755 index 000000000..a877b5067 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./transducer_stateless/export.py \ + --exp-dir ./transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `transducer_stateless/decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./transducer_stateless/decode.py \ + --exp-dir ./transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 1 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from model import Transducer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. 
+        """,
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "feature_dim": 80,
+            "encoder_out_dim": 512,
+            "subsampling_factor": 4,
+            "attention_dim": 512,
+            "nhead": 8,
+            "dim_feedforward": 2048,
+            "num_encoder_layers": 12,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            # parameters for decoder
+            "context_size": 2,  # tri-gram
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def get_encoder_model(params: AttributeDict):
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        output_dim=params.encoder_out_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.attention_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict):
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        embedding_dim=params.encoder_out_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict):
+    joiner = Joiner(
+        input_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict):
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+    )
+    return model
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    assert args.jit is False, "Torchscript support will be added later"
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    model.to(device)
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.eval()
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 7053f621e..2f0f9a183 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface
 
 from icefall.utils import add_sos
 
-assert hasattr(torchaudio.functional, "rnnt_loss"), (
-    f"Current torchaudio version: {torchaudio.__version__}\n"
-    "Please install a version >= 0.10.0"
-)
-
 
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -113,6 +108,11 @@ class Transducer(nn.Module):
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)
 
+        assert hasattr(torchaudio.functional, "rnnt_loss"), (
+            f"Current torchaudio version: {torchaudio.__version__}\n"
+            "Please install a version >= 0.10.0"
+        )
+
         loss = torchaudio.functional.rnnt_loss(
             logits=logits,
             targets=y_padded,
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
new file mode 100755
index 000000000..49efa6749
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -0,0 +1,307 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./transducer_stateless/pretrained.py \
+    --checkpoint ./transducer_stateless/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./transducer_stateless/pretrained.py \
+    --checkpoint ./transducer_stateless/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./transducer_stateless/exp/epoch-xx.pt`.
+
+Note: ./transducer_stateless/exp/pretrained.pt is generated by
+./transducer_stateless/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import beam_search, greedy_search
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="Used only when --method is beam_search",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "sample_rate": 16000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "encoder_out_dim": 512,
+            "subsampling_factor": 4,
+            "attention_dim": 512,
+            "nhead": 8,
+            "dim_feedforward": 2048,
+            "num_encoder_layers": 12,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            # parameters for decoder
+            "context_size": 2,  # tri-gram
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def get_encoder_model(params: AttributeDict):
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        output_dim=params.encoder_out_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.attention_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict):
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        embedding_dim=params.encoder_out_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict):
+    joiner = Joiner(
+        input_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict):
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+    )
+    return model
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + with torch.no_grad(): + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search(model=model, encoder_out=encoder_out_i) + elif params.method == "beam_search": + hyp = beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer_stateless/test_decoder.py b/egs/librispeech/ASR/transducer_stateless/test_decoder.py new file mode 100755 index 000000000..3a653c1b7 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/test_decoder.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer_stateless/test_decoder.py
+"""
+
+import torch
+from decoder import Decoder
+
+
+def test_decoder():
+    vocab_size = 3
+    blank_id = 0
+    embedding_dim = 128
+    context_size = 4
+
+    decoder = Decoder(
+        vocab_size=vocab_size,
+        embedding_dim=embedding_dim,
+        blank_id=blank_id,
+        context_size=context_size,
+    )
+    N = 100
+    U = 20
+    x = torch.randint(low=0, high=vocab_size, size=(N, U))
+    y = decoder(x)
+    assert y.shape == (N, U, embedding_dim)
+
+    # for inference
+    x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
+    y = decoder(x, need_pad=False)
+    assert y.shape == (N, 1, embedding_dim)
+
+
+def main():
+    test_decoder()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index e20aedf9b..a2bf4700c 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -92,7 +92,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=78,
+        default=30,
         help="Number of epochs to train.",
     )
 
@@ -202,6 +202,8 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
+            # parameters for decoder
+            "context_size": 2,  # tri-gram
             # parameters for Noam
             "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k
@@ -233,6 +235,7 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
+        context_size=params.context_size,
     )
     return decoder
 

From 5b6699a8354b70b23b252b371c612a35ed186ec2 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Thu, 23 Dec 2021 13:54:25 +0800
Subject: [PATCH 2/2] Minor fixes to the RNN-T Conformer model (#152)

* Disable weight decay.

* Remove input feature batchnorm.

* Replace BatchNorm in the Conformer model with LayerNorm.

* Use tanh in the joint network.

* Remove sos ID.

* Reduce the number of decoder layers from 4 to 2.

* Minor fixes.

* Fix typos.
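
For reference, a minimal sketch of the joiner's forward pass after the
tanh change (shapes follow the docstrings in
egs/librispeech/ASR/transducer/joiner.py; tensor names are illustrative):

    # encoder_out: (N, T, C), decoder_out: (N, U, C)
    encoder_out = encoder_out.unsqueeze(2)  # (N, T, 1, C)
    decoder_out = decoder_out.unsqueeze(1)  # (N, 1, U, C)
    logit = torch.tanh(encoder_out + decoder_out)  # (N, T, U, C)
    output = self.output_linear(logit)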
---
 ...d.yml => run-pretrained-conformer-ctc.yml} |   0
 .../run-pretrained-transducer-stateless.yml   |   2 +-
 .../workflows/run-pretrained-transducer.yml   | 109 ++++++++++++++++++
 README.md                                     |   2 +-
 egs/librispeech/ASR/RESULTS.md                |  23 ++--
 egs/librispeech/ASR/transducer/beam_search.py |   3 +-
 egs/librispeech/ASR/transducer/conformer.py   |  16 ++-
 egs/librispeech/ASR/transducer/decode.py      |  10 +-
 egs/librispeech/ASR/transducer/decoder.py     |   4 -
 egs/librispeech/ASR/transducer/export.py      |  14 +--
 egs/librispeech/ASR/transducer/joiner.py      |   3 +-
 egs/librispeech/ASR/transducer/model.py       |   6 +-
 egs/librispeech/ASR/transducer/pretrained.py  |   6 +-
 .../ASR/transducer/test_conformer.py          |   1 -
 .../ASR/transducer/test_decoder.py            |   2 -
 .../ASR/transducer/test_transducer.py         |   3 -
 .../ASR/transducer/test_transformer.py        |   1 -
 egs/librispeech/ASR/transducer/train.py       |  17 +--
 egs/librispeech/ASR/transducer/transformer.py |  11 --
 19 files changed, 147 insertions(+), 86 deletions(-)
 rename .github/workflows/{run-pretrained.yml => run-pretrained-conformer-ctc.yml} (100%)
 create mode 100644 .github/workflows/run-pretrained-transducer.yml

diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained-conformer-ctc.yml
similarity index 100%
rename from .github/workflows/run-pretrained.yml
rename to .github/workflows/run-pretrained-conformer-ctc.yml
diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
index 7af2299a4..026d3967c 100644
--- a/.github/workflows/run-pretrained-transducer-stateless.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless.yml
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-name: run-pre-trained-tranducer-stateless
+name: run-pre-trained-transducer-stateless
 
 on:
   push:
diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml
new file mode 100644
index 000000000..f0ebddba3
--- /dev/null
+++ b/.github/workflows/run-pretrained-transducer.yml
@@ -0,0 +1,109 @@
+# Copyright      2021  Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
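+
+# This workflow downloads a pre-trained RNN-T Conformer model from
+# Hugging Face and decodes a few test wavs with it, using both greedy
+# search and beam search.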
+
+name: run-pre-trained-transducer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+jobs:
+  run_pre_trained_transducer:
+    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-18.04]
+        python-version: [3.7, 3.8, 3.9]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        k2-version: ["1.9.dev20211101"]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Python dependencies
+        run: |
+          python3 -m pip install --upgrade pip pytest
+          # numpy 1.20.x does not support python 3.6
+          pip install numpy==1.19
+          pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+
+          python3 -m pip install git+https://github.com/lhotse-speech/lhotse
+          python3 -m pip install kaldifeat
+          # We are in ./icefall and there is a file: requirements.txt in it
+          pip install -r requirements.txt
+
+      - name: Install graphviz
+        shell: bash
+        run: |
+          python3 -m pip install -qq graphviz
+          sudo apt-get -qq install graphviz
+
+      - name: Download pre-trained model
+        shell: bash
+        run: |
+          sudo apt-get -qq install git-lfs tree sox
+          cd egs/librispeech/ASR
+          mkdir tmp
+          cd tmp
+          git lfs install
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
+
+          cd ..
+          tree tmp
+          soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
+          ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
+
+      - name: Run greedy search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer/pretrained.py \
+            --method greedy_search \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
+
+      - name: Run beam search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer/pretrained.py \
+            --method beam_search \
+            --beam-size 4 \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
diff --git a/README.md b/README.md
index 931fb0198..f0a678839 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ The best WER with greedy search is:
 
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 3.16       | 7.71       |
+| WER | 3.07       | 7.51       |
 
 We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 317b1591a..aab2b61e0 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -2,7 +2,10 @@
 
 ### LibriSpeech BPE training results (Transducer)
 
-#### 2021-12-22
+#### Conformer encoder + embedding decoder
+
+Using commit `fb6a57e9e01dd8aae2af2a6b4568daad8bc8ab32`.
+
 Conformer encoder + non-recurrent decoder. The decoder
 contains only an embedding layer and a Conv1d (with kernel size 2).
 
@@ -60,8 +63,8 @@ avg=10
 ```
 
 
-#### 2021-12-17
-Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
+#### Conformer encoder + LSTM decoder
+Using commit `TODO`.
 
 Conformer encoder + LSTM decoder.
 
@@ -69,9 +72,9 @@ The best WER is
 
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 3.16       | 7.71       |
+| WER | 3.07       | 7.51       |
 
-using `--epoch 26 --avg 12` with **greedy search**.
+using `--epoch 34 --avg 11` with **greedy search**.
 
 The training command to reproduce the above WER is:
 
@@ -80,19 +83,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 ./transducer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 35 \
   --start-epoch 0 \
   --exp-dir transducer/exp-lr-2.5-full \
   --full-libri 1 \
-  --max-duration 250 \
+  --max-duration 180 \
   --lr-factor 2.5
 ```
 
 The decoding command is:
 
 ```
-epoch=26
-avg=12
+epoch=34
+avg=11
 
 ./transducer/decode.py \
   --epoch $epoch \
@@ -102,7 +105,7 @@ avg=12
   --max-duration 100
 ```
 
-You can find the tensorboard log at: 
+You can find the tensorboard log at: 
 
 
 ### LibriSpeech BPE training results (Conformer-CTC)
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index dfc22fcf8..f45d06ce9 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -111,7 +111,6 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
     device = model.device
 
     sos = torch.tensor([blank_id], device=device).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
@@ -192,7 +191,7 @@ def beam_search(
 
             # Second, choose other labels
             for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
+                if i == blank_id:
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
diff --git a/egs/librispeech/ASR/transducer/conformer.py b/egs/librispeech/ASR/transducer/conformer.py
index 245aaa428..81d7708f9 100644
--- a/egs/librispeech/ASR/transducer/conformer.py
+++ b/egs/librispeech/ASR/transducer/conformer.py
@@ -56,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,7 +68,6 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            use_feat_batchnorm=use_feat_batchnorm,
         )
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ class Conformer(Transformer):
             - logit_lens, a tensor of shape (batch_size,) containing the number
               of frames in `logits` before padding.
         """
""" - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - self.norm = nn.BatchNorm1d(channels) + self.norm = nn.LayerNorm(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, @@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 80b72a89f..ef0992618 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -70,14 +70,14 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=26, + default=34, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=12, + default=11, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -129,10 +129,9 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, # decoder params "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, + "num_decoder_layers": 2, "decoder_hidden_dim": 512, "env_info": get_env_info(), } @@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.decoder_embedding_dim, blank_id=params.blank_id, - sos_id=params.sos_id, num_layers=params.num_decoder_layers, hidden_dim=params.decoder_hidden_dim, output_dim=params.encoder_out_dim, @@ -401,7 +398,6 @@ def main(): # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py index 2f6bf4c07..7b529ac19 100644 --- a/egs/librispeech/ASR/transducer/decoder.py +++ b/egs/librispeech/ASR/transducer/decoder.py @@ -27,7 +27,6 @@ class Decoder(nn.Module): vocab_size: int, embedding_dim: int, blank_id: int, - sos_id: int, num_layers: int, hidden_dim: int, output_dim: int, @@ -42,8 +41,6 @@ class Decoder(nn.Module): Dimension of the input embedding. blank_id: The ID of the blank symbol. - sos_id: - The ID of the SOS symbol. num_layers: Number of LSTM layers. 
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 80b72a89f..ef0992618 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=26,
+        default=34,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=12,
+        default=11,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -129,10 +129,9 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
@@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -401,7 +398,6 @@ def main():
 
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py
index 2f6bf4c07..7b529ac19 100644
--- a/egs/librispeech/ASR/transducer/decoder.py
+++ b/egs/librispeech/ASR/transducer/decoder.py
@@ -27,7 +27,6 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
         num_layers: int,
         hidden_dim: int,
         output_dim: int,
@@ -42,8 +41,6 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
           num_layers:
             Number of LSTM layers.
           hidden_dim:
@@ -71,7 +68,6 @@ class Decoder(nn.Module):
             dropout=rnn_dropout,
         )
         self.blank_id = blank_id
-        self.sos_id = sos_id
         self.output_linear = nn.Linear(hidden_dim, output_dim)
 
     def forward(
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 27fa8974e..3351fbc67 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -23,8 +23,8 @@ Usage:
 ./transducer/export.py \
   --exp-dir ./transducer/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 26 \
-  --avg 12
+  --epoch 34 \
+  --avg 11
 
 It will generate a file exp_dir/pretrained.pt
 
@@ -66,7 +66,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=26,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
@@ -74,7 +74,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=12,
+        default=11,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -119,10 +119,9 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
@@ -140,7 +139,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -150,7 +148,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -199,7 +196,6 @@ def main():
 
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
diff --git a/egs/librispeech/ASR/transducer/joiner.py b/egs/librispeech/ASR/transducer/joiner.py
index 0422f8a6f..2ef3f1de6 100644
--- a/egs/librispeech/ASR/transducer/joiner.py
+++ b/egs/librispeech/ASR/transducer/joiner.py
@@ -16,7 +16,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
@@ -48,7 +47,7 @@ class Joiner(nn.Module):
         # Now decoder_out is (N, 1, U, C)
 
         logit = encoder_out + decoder_out
-        logit = F.relu(logit)
+        logit = torch.tanh(logit)
 
         output = self.output_linear(logit)
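For orientation, after this change the joiner reduces to a broadcast addition of the encoder and decoder outputs, a `tanh` (previously `F.relu`), and a linear projection. The sketch below assumes `unsqueeze` calls produce the `(N, T, 1, C)` and `(N, 1, U, C)` shapes implied by the comment in the diff; the class and argument names are illustrative:

```python
import torch
import torch.nn as nn


class JoinerSketch(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor,
                decoder_out: torch.Tensor) -> torch.Tensor:
        # encoder_out: (N, T, C) -> (N, T, 1, C)
        # decoder_out: (N, U, C) -> (N, 1, U, C)
        logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)
        logit = torch.tanh(logit)         # tanh replaces relu here
        return self.output_linear(logit)  # (N, T, U, output_dim)


joiner = JoinerSketch(input_dim=512, output_dim=500)
out = joiner(torch.randn(2, 100, 512), torch.randn(2, 10, 512))
assert out.shape == (2, 100, 10, 500)
```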
diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py
index cb9afd8a2..fa0b2dd68 100644
--- a/egs/librispeech/ASR/transducer/model.py
+++ b/egs/librispeech/ASR/transducer/model.py
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            two attributes: `blank_id` and `sos_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
@@ -58,7 +58,6 @@ class Transducer(nn.Module):
         super().__init__()
         assert isinstance(encoder, EncoderInterface)
         assert hasattr(decoder, "blank_id")
-        assert hasattr(decoder, "sos_id")
 
         self.encoder = encoder
         self.decoder = decoder
@@ -97,8 +96,7 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]
 
         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
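With `sos_id` removed, the blank symbol doubles as the start-of-sequence context: `add_sos(y, sos_id=blank_id)` prepends blank to every label sequence before the decoder consumes it. The recipe does this on a `k2` ragged tensor; the snippet below reproduces only the idea on an ordinary padded tensor:

```python
import torch

blank_id = 0
# Two padded label sequences (blank_id also serves as padding)
y_padded = torch.tensor([[7, 3, 9],
                         [5, 2, 0]])

# Mirror add_sos(y, sos_id=blank_id): prepend blank as the "SOS" symbol,
# so the decoder's input at step u is the blank-prefixed history y[:u].
sos = torch.full((y_padded.size(0), 1), blank_id, dtype=y_padded.dtype)
sos_y_padded = torch.cat([sos, y_padded], dim=1)
# tensor([[0, 7, 3, 9],
#         [0, 5, 2, 0]])
```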
         output_dim=output_dim,
diff --git a/egs/librispeech/ASR/transducer/test_transformer.py b/egs/librispeech/ASR/transducer/test_transformer.py
index 8f4585504..bb68c22be 100755
--- a/egs/librispeech/ASR/transducer/test_transformer.py
+++ b/egs/librispeech/ASR/transducer/test_transformer.py
@@ -36,7 +36,6 @@ def test_transformer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     N = 3
     T = 100
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 5d0b2d33a..dcb75609c 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 ./transducer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 35 \
   --start-epoch 0 \
   --exp-dir transducer/exp \
   --full-libri 1 \
@@ -92,7 +92,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=30,
+        default=35,
         help="Number of epochs to train.",
     )
 
@@ -171,15 +171,10 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - use_feat_batchnorm: Whether to do batch normalization for the
-          input features.
-
         - attention_dim: Hidden dim for multi-head attention model.
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - weight_decay: The weight_decay for the optimizer.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -201,13 +196,11 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             # parameters for Noam
-            "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
@@ -227,7 +220,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -237,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -575,7 +566,6 @@ def run(rank, world_size, args):
 
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -599,7 +589,6 @@ def run(rank, world_size, args):
         model_size=params.attention_dim,
         factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )
 
     if checkpoints and "optimizer" in checkpoints:
diff --git a/egs/librispeech/ASR/transducer/transformer.py b/egs/librispeech/ASR/transducer/transformer.py
index 814290264..e851dcc32 100644
--- a/egs/librispeech/ASR/transducer/transformer.py
+++ b/egs/librispeech/ASR/transducer/transformer.py
@@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)
 
         self.num_features = num_features
         self.output_dim = output_dim
@@ -131,11 +125,6 @@
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
         x = self.encoder_embed(x)
         x = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)