From 14c93add507982306f5a478cd144e0e32e0f970d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 27 Dec 2021 16:01:10 +0800 Subject: [PATCH] Remove batchnorm, weight decay, and SOS from transducer conformer encoder (#155) * Remove batchnorm, weight decay, and SOS. * Make --context-size configurable. * Update results. --- .../run-pretrained-transducer-stateless.yml | 26 +++++++++---------- README.md | 2 +- egs/librispeech/ASR/RESULTS.md | 18 +++++-------- egs/librispeech/ASR/transducer/decode.py | 2 +- egs/librispeech/ASR/transducer/export.py | 2 +- egs/librispeech/ASR/transducer/pretrained.py | 2 +- egs/librispeech/ASR/transducer/train.py | 2 +- .../ASR/transducer_stateless/conformer.py | 16 +++++------- .../ASR/transducer_stateless/decode.py | 16 +++++++----- .../ASR/transducer_stateless/decoder.py | 5 ++-- .../ASR/transducer_stateless/export.py | 12 ++++++--- .../ASR/transducer_stateless/joiner.py | 3 +-- .../ASR/transducer_stateless/pretrained.py | 11 +++++--- .../ASR/transducer_stateless/train.py | 21 +++++++-------- .../ASR/transducer_stateless/transformer.py | 11 -------- 15 files changed, 70 insertions(+), 79 deletions(-) diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 026d3967c..3bbd4c49b 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -74,11 +74,11 @@ jobs: mkdir tmp cd tmp git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22 + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27 cd .. tree tmp - soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav - ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav + soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav + ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav - name: Run greedy search decoding shell: bash @@ -87,11 +87,11 @@ jobs: cd egs/librispeech/ASR ./transducer_stateless/pretrained.py \ --method greedy_search \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav - name: Run beam search decoding shell: bash @@ -101,8 +101,8 @@ jobs: ./transducer_stateless/pretrained.py \ --method beam_search \ --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav + --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \ + ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav diff --git a/README.md b/README.md index f0a678839..ff93e8fad 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ The best WER using beam search with beam size 4 is: | | test-clean | test-other | |-----|------------|------------| -| WER | 2.92 | 7.37 | +| WER | 2.83 | 7.19 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index aab2b61e0..8ff535932 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -4,7 +4,7 @@ #### Conformer encoder + embedding decoder -Using commit `fb6a57e9e01dd8aae2af2a6b4568daad8bc8ab32`. +Using commit `TODO`. Conformer encoder + non-current decoder. The decoder contains only an embedding layer and a Conv1d (with kernel size 2). @@ -13,12 +13,8 @@ The WERs are | | test-clean | test-other | comment | |---------------------------|------------|------------|------------------------------------------| -| greedy search | 2.99 | 7.52 | --epoch 20, --avg 10, --max-duration 100 | -| beam search (beam size 2) | 2.95 | 7.43 | | -| beam search (beam size 3) | 2.94 | 7.37 | | -| beam search (beam size 4) | 2.92 | 7.37 | | -| beam search (beam size 5) | 2.93 | 7.38 | | -| beam search (beam size 8) | 2.92 | 7.38 | | +| greedy search | 2.85 | 7.30 | --epoch 29, --avg 13, --max-duration 100 | +| beam search (beam size 4) | 2.83 | 7.19 | | The training command for reproducing is given below: @@ -36,12 +32,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" ``` The tensorboard training log can be found at - + The decoding command is: ``` -epoch=20 -avg=10 +epoch=29 +avg=13 ## greedy search ./transducer_stateless/decode.py \ @@ -64,7 +60,7 @@ avg=10 #### Conformer encoder + LSTM decoder -Using commit `TODO`. +Using commit `8187d6236c2926500da5ee854f758e621df803cc`. Conformer encoder + LSTM decoder. diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index ef0992618..990513ed9 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -396,7 +396,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 3351fbc67..5a5db30c4 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -194,7 +194,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index f27938de6..1db2df648 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -208,7 +208,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index dcb75609c..903ba8491 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -564,7 +564,7 @@ def run(rank, world_size, args): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 245aaa428..81d7708f9 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -56,7 +56,6 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -69,7 +68,6 @@ class Conformer(Transformer): dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, ) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -107,11 +105,6 @@ class Conformer(Transformer): - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - self.norm = nn.BatchNorm1d(channels) + self.norm = nn.LayerNorm(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, @@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 51bebed5a..e5987b75e 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -70,14 +70,14 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=20, + default=29, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, - default=10, + default=13, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -114,6 +114,13 @@ def get_parser(): help="Used only when --decoding-method is beam_search", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) parser.add_argument( "--max-sym-per-frame", type=int, @@ -136,9 +143,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -156,7 +160,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -393,6 +396,7 @@ def main(): if params.decoding_method == "beam_search": params.suffix += f"-beam-{params.beam_size}" else: + params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index cedbc937e..dca084477 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -20,13 +20,14 @@ import torch.nn.functional as F class Decoder(nn.Module): - """This class implements the stateless decoder from the following paper: + """This class modifies the stateless decoder from the following paper: RNN-transducer with stateless prediction network https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 It removes the recurrent connection from the decoder, i.e., the prediction - network. + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. TODO: Implement https://arxiv.org/pdf/2109.07513.pdf """ diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index a877b5067..641555bdb 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -104,6 +104,14 @@ def get_parser(): """, ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -119,9 +127,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 0422f8a6f..2ef3f1de6 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -16,7 +16,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F class Joiner(nn.Module): @@ -48,7 +47,7 @@ class Joiner(nn.Module): # Now decoder_out is (N, 1, U, C) logit = encoder_out + decoder_out - logit = F.relu(logit) + logit = torch.tanh(logit) output = self.output_linear(logit) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 6a6626371..e5dba8f0e 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -110,6 +110,13 @@ def get_parser(): help="Used only when --method is beam_search", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) parser.add_argument( "--max-sym-per-frame", type=int, @@ -135,9 +142,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -154,7 +158,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index a2bf4700c..694ebf1d5 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -130,6 +130,14 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -171,15 +179,10 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - attention_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. - - weight_decay: The weight_decay for the optimizer. - - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( @@ -201,11 +204,7 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram # parameters for Noam - "weight_decay": 1e-6, "warm_step": 80000, # For the 100h subset, use 8k "env_info": get_env_info(), } @@ -225,7 +224,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -568,7 +566,7 @@ def run(rank, world_size, args): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and are defined in local/train_bpe_model.py + # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -593,7 +591,6 @@ def run(rank, world_size, args): model_size=params.attention_dim, factor=params.lr_factor, warm_step=params.warm_step, - weight_decay=params.weight_decay, ) if checkpoints and "optimizer" in checkpoints: diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py index 814290264..e851dcc32 100644 --- a/egs/librispeech/ASR/transducer_stateless/transformer.py +++ b/egs/librispeech/ASR/transducer_stateless/transformer.py @@ -39,7 +39,6 @@ class Transformer(EncoderInterface): dropout: float = 0.1, normalize_before: bool = True, vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, ) -> None: """ Args: @@ -65,13 +64,8 @@ class Transformer(EncoderInterface): If True, use pre-layer norm; False to use post-layer norm. vgg_frontend: True to use vgg style frontend for subsampling. - use_feat_batchnorm: - True to use batchnorm for the input layer. """ super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - if use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) self.num_features = num_features self.output_dim = output_dim @@ -131,11 +125,6 @@ class Transformer(EncoderInterface): - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - x = self.encoder_embed(x) x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)