diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained-conformer-ctc.yml similarity index 100% rename from .github/workflows/run-pretrained.yml rename to .github/workflows/run-pretrained-conformer-ctc.yml diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 7af2299a4..026d3967c 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: run-pre-trained-tranducer-stateless +name: run-pre-trained-trandsucer-stateless on: push: diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml new file mode 100644 index 000000000..f0ebddba3 --- /dev/null +++ b/.github/workflows/run-pretrained-transducer.yml @@ -0,0 +1,109 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-pre-trained-transducer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +jobs: + run_pre_trained_transducer: + if: github.event.label.name == 'ready' || github.event_name == 'push' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.7, 3.8, 3.9] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + k2-version: ["1.9.dev20211101"] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + run: | + python3 -m pip install --upgrade pip pytest + # numpy 1.20.x does not support python 3.6 + pip install numpy==1.19 + pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + + python3 -m pip install git+https://github.com/lhotse-speech/lhotse + python3 -m pip install kaldifeat + # We are in ./icefall and there is a file: requirements.txt in it + pip install -r requirements.txt + + - name: Install graphviz + shell: bash + run: | + python3 -m pip install -qq graphviz + sudo apt-get -qq install graphviz + + - name: Download pre-trained model + shell: bash + run: | + sudo apt-get -qq install git-lfs tree sox + cd egs/librispeech/ASR + mkdir tmp + cd tmp + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23 + + cd .. + tree tmp + soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav + ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav + + - name: Run greedy search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/librispeech/ASR + ./transducer/pretrained.py \ + --method greedy_search \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav + + - name: Run beam search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/librispeech/ASR + ./transducer/pretrained.py \ + --method beam_search \ + --beam-size 4 \ + --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav diff --git a/README.md b/README.md index 931fb0198..f0a678839 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ The best WER with greedy search is: | | test-clean | test-other | |-----|------------|------------| -| WER | 3.16 | 7.71 | +| WER | 3.07 | 7.51 | We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 317b1591a..aab2b61e0 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -2,7 +2,10 @@ ### LibriSpeech BPE training results (Transducer) -#### 2021-12-22 +#### Conformer encoder + embedding decoder + +Using commit `fb6a57e9e01dd8aae2af2a6b4568daad8bc8ab32`. + Conformer encoder + non-current decoder. The decoder contains only an embedding layer and a Conv1d (with kernel size 2). @@ -60,8 +63,8 @@ avg=10 ``` -#### 2021-12-17 -Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`. +#### Conformer encoder + LSTM decoder +Using commit `TODO`. Conformer encoder + LSTM decoder. @@ -69,9 +72,9 @@ The best WER is | | test-clean | test-other | |-----|------------|------------| -| WER | 3.16 | 7.71 | +| WER | 3.07 | 7.51 | -using `--epoch 26 --avg 12` with **greedy search**. +using `--epoch 34 --avg 11` with **greedy search**. The training command to reproduce the above WER is: @@ -80,19 +83,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" ./transducer/train.py \ --world-size 4 \ - --num-epochs 30 \ + --num-epochs 35 \ --start-epoch 0 \ --exp-dir transducer/exp-lr-2.5-full \ --full-libri 1 \ - --max-duration 250 \ + --max-duration 180 \ --lr-factor 2.5 ``` The decoding command is: ``` -epoch=26 -avg=12 +epoch=34 +avg=11 ./transducer/decode.py \ --epoch $epoch \ @@ -102,7 +105,7 @@ avg=12 --max-duration 100 ``` -You can find the tensorboard log at: +You can find the tensorboard log at: ### LibriSpeech BPE training results (Conformer-CTC) diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py index dfc22fcf8..f45d06ce9 100644 --- a/egs/librispeech/ASR/transducer/beam_search.py +++ b/egs/librispeech/ASR/transducer/beam_search.py @@ -111,7 +111,6 @@ def beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id - sos_id = model.decoder.sos_id device = model.device sos = torch.tensor([blank_id], device=device).reshape(1, 1) @@ -192,7 +191,7 @@ def beam_search( # Second, choose other labels for i, v in enumerate(log_prob.tolist()): - if i in (blank_id, sos_id): + if i == blank_id: continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v diff --git a/egs/librispeech/ASR/transducer/conformer.py b/egs/librispeech/ASR/transducer/conformer.py index 245aaa428..81d7708f9 100644 --- a/egs/librispeech/ASR/transducer/conformer.py +++ b/egs/librispeech/ASR/transducer/conformer.py @@ -56,7 +56,6 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -69,7 +68,6 @@ class Conformer(Transformer): dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, ) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -107,11 +105,6 @@ class Conformer(Transformer): - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - self.norm = nn.BatchNorm1d(channels) + self.norm = nn.LayerNorm(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, @@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 80b72a89f..ef0992618 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -70,14 +70,14 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=26, + default=34, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=12, + default=11, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -129,10 +129,9 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, # decoder params "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, + "num_decoder_layers": 2, "decoder_hidden_dim": 512, "env_info": get_env_info(), } @@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.decoder_embedding_dim, blank_id=params.blank_id, - sos_id=params.sos_id, num_layers=params.num_decoder_layers, hidden_dim=params.decoder_hidden_dim, output_dim=params.encoder_out_dim, @@ -401,7 +398,6 @@ def main(): # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py index 2f6bf4c07..7b529ac19 100644 --- a/egs/librispeech/ASR/transducer/decoder.py +++ b/egs/librispeech/ASR/transducer/decoder.py @@ -27,7 +27,6 @@ class Decoder(nn.Module): vocab_size: int, embedding_dim: int, blank_id: int, - sos_id: int, num_layers: int, hidden_dim: int, output_dim: int, @@ -42,8 +41,6 @@ class Decoder(nn.Module): Dimension of the input embedding. blank_id: The ID of the blank symbol. - sos_id: - The ID of the SOS symbol. num_layers: Number of LSTM layers. hidden_dim: @@ -71,7 +68,6 @@ class Decoder(nn.Module): dropout=rnn_dropout, ) self.blank_id = blank_id - self.sos_id = sos_id self.output_linear = nn.Linear(hidden_dim, output_dim) def forward( diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 27fa8974e..3351fbc67 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -23,8 +23,8 @@ Usage: ./transducer/export.py \ --exp-dir ./transducer/exp \ --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 26 \ - --avg 12 + --epoch 34 \ + --avg 11 It will generate a file exp_dir/pretrained.pt @@ -66,7 +66,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=26, + default=34, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -74,7 +74,7 @@ def get_parser(): parser.add_argument( "--avg", type=int, - default=12, + default=11, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -119,10 +119,9 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, # decoder params "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, + "num_decoder_layers": 2, "decoder_hidden_dim": 512, "env_info": get_env_info(), } @@ -140,7 +139,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -150,7 +148,6 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.decoder_embedding_dim, blank_id=params.blank_id, - sos_id=params.sos_id, num_layers=params.num_decoder_layers, hidden_dim=params.decoder_hidden_dim, output_dim=params.encoder_out_dim, @@ -199,7 +196,6 @@ def main(): # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/transducer/joiner.py b/egs/librispeech/ASR/transducer/joiner.py index 0422f8a6f..2ef3f1de6 100644 --- a/egs/librispeech/ASR/transducer/joiner.py +++ b/egs/librispeech/ASR/transducer/joiner.py @@ -16,7 +16,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F class Joiner(nn.Module): @@ -48,7 +47,7 @@ class Joiner(nn.Module): # Now decoder_out is (N, 1, U, C) logit = encoder_out + decoder_out - logit = F.relu(logit) + logit = torch.tanh(logit) output = self.output_linear(logit) diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index cb9afd8a2..fa0b2dd68 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -49,7 +49,7 @@ class Transducer(nn.Module): decoder: It is the prediction network in the paper. Its input shape is (N, U) and its output shape is (N, U, C). It should contain - two attributes: `blank_id` and `sos_id`. + one attribute: `blank_id`. joiner: It has two inputs with shapes: (N, T, C) and (N, U, C). Its output shape is (N, T, U, C). Note that its output contains @@ -58,7 +58,6 @@ class Transducer(nn.Module): super().__init__() assert isinstance(encoder, EncoderInterface) assert hasattr(decoder, "blank_id") - assert hasattr(decoder, "sos_id") self.encoder = encoder self.decoder = decoder @@ -97,8 +96,7 @@ class Transducer(nn.Module): y_lens = row_splits[1:] - row_splits[:-1] blank_id = self.decoder.blank_id - sos_id = self.decoder.sos_id - sos_y = add_sos(y, sos_id=sos_id) + sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 4cf4fd4a7..f27938de6 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -116,10 +116,9 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, # decoder params "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, + "num_decoder_layers": 2, "decoder_hidden_dim": 512, "env_info": get_env_info(), } @@ -137,7 +136,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -147,7 +145,6 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.decoder_embedding_dim, blank_id=params.blank_id, - sos_id=params.sos_id, num_layers=params.num_decoder_layers, hidden_dim=params.decoder_hidden_dim, output_dim=params.encoder_out_dim, @@ -213,7 +210,6 @@ def main(): # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(f"{params}") diff --git a/egs/librispeech/ASR/transducer/test_conformer.py b/egs/librispeech/ASR/transducer/test_conformer.py index 5d941d98a..9529e9c59 100755 --- a/egs/librispeech/ASR/transducer/test_conformer.py +++ b/egs/librispeech/ASR/transducer/test_conformer.py @@ -36,7 +36,6 @@ def test_conformer(): nhead=8, dim_feedforward=2048, num_encoder_layers=12, - use_feat_batchnorm=True, ) N = 3 T = 100 diff --git a/egs/librispeech/ASR/transducer/test_decoder.py b/egs/librispeech/ASR/transducer/test_decoder.py index 44c6eb6db..f0a7aa9cc 100755 --- a/egs/librispeech/ASR/transducer/test_decoder.py +++ b/egs/librispeech/ASR/transducer/test_decoder.py @@ -29,7 +29,6 @@ from decoder import Decoder def test_decoder(): vocab_size = 3 blank_id = 0 - sos_id = 2 embedding_dim = 128 num_layers = 2 hidden_dim = 6 @@ -41,7 +40,6 @@ def test_decoder(): vocab_size=vocab_size, embedding_dim=embedding_dim, blank_id=blank_id, - sos_id=sos_id, num_layers=num_layers, hidden_dim=hidden_dim, output_dim=output_dim, diff --git a/egs/librispeech/ASR/transducer/test_transducer.py b/egs/librispeech/ASR/transducer/test_transducer.py index bd4f2c188..15aa3b330 100755 --- a/egs/librispeech/ASR/transducer/test_transducer.py +++ b/egs/librispeech/ASR/transducer/test_transducer.py @@ -39,7 +39,6 @@ def test_transducer(): # decoder params vocab_size = 3 blank_id = 0 - sos_id = 2 embedding_dim = 128 num_layers = 2 @@ -51,14 +50,12 @@ def test_transducer(): nhead=8, dim_feedforward=2048, num_encoder_layers=12, - use_feat_batchnorm=True, ) decoder = Decoder( vocab_size=vocab_size, embedding_dim=embedding_dim, blank_id=blank_id, - sos_id=sos_id, num_layers=num_layers, hidden_dim=output_dim, output_dim=output_dim, diff --git a/egs/librispeech/ASR/transducer/test_transformer.py b/egs/librispeech/ASR/transducer/test_transformer.py index 8f4585504..bb68c22be 100755 --- a/egs/librispeech/ASR/transducer/test_transformer.py +++ b/egs/librispeech/ASR/transducer/test_transformer.py @@ -36,7 +36,6 @@ def test_transformer(): nhead=8, dim_feedforward=2048, num_encoder_layers=12, - use_feat_batchnorm=True, ) N = 3 T = 100 diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index 5d0b2d33a..dcb75609c 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" ./transducer/train.py \ --world-size 4 \ - --num-epochs 30 \ + --num-epochs 35 \ --start-epoch 0 \ --exp-dir transducer/exp \ --full-libri 1 \ @@ -92,7 +92,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=30, + default=35, help="Number of epochs to train.", ) @@ -171,15 +171,10 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - attention_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. - - weight_decay: The weight_decay for the optimizer. - - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( @@ -201,13 +196,11 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, # decoder params "decoder_embedding_dim": 1024, - "num_decoder_layers": 4, + "num_decoder_layers": 2, "decoder_hidden_dim": 512, # parameters for Noam - "weight_decay": 1e-6, "warm_step": 80000, # For the 100h subset, use 8k "env_info": get_env_info(), } @@ -227,7 +220,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -237,7 +229,6 @@ def get_decoder_model(params: AttributeDict): vocab_size=params.vocab_size, embedding_dim=params.decoder_embedding_dim, blank_id=params.blank_id, - sos_id=params.sos_id, num_layers=params.num_decoder_layers, hidden_dim=params.decoder_hidden_dim, output_dim=params.encoder_out_dim, @@ -575,7 +566,6 @@ def run(rank, world_size, args): # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.sos_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -599,7 +589,6 @@ def run(rank, world_size, args): model_size=params.attention_dim, factor=params.lr_factor, warm_step=params.warm_step, - weight_decay=params.weight_decay, ) if checkpoints and "optimizer" in checkpoints: diff --git a/egs/librispeech/ASR/transducer/transformer.py b/egs/librispeech/ASR/transducer/transformer.py index 814290264..e851dcc32 100644 --- a/egs/librispeech/ASR/transducer/transformer.py +++ b/egs/librispeech/ASR/transducer/transformer.py @@ -39,7 +39,6 @@ class Transformer(EncoderInterface): dropout: float = 0.1, normalize_before: bool = True, vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, ) -> None: """ Args: @@ -65,13 +64,8 @@ class Transformer(EncoderInterface): If True, use pre-layer norm; False to use post-layer norm. vgg_frontend: True to use vgg style frontend for subsampling. - use_feat_batchnorm: - True to use batchnorm for the input layer. """ super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - if use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) self.num_features = num_features self.output_dim = output_dim @@ -131,11 +125,6 @@ class Transformer(EncoderInterface): - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - x = self.encoder_embed(x) x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)