From a183d5bfd7f7c3ad16ed38c83a8a77933c085002 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Tue, 14 Dec 2021 08:20:03 +0800
Subject: [PATCH] Remove batchnorm (#147)

* Remove batch normalization

* Minor fixes

* Fix typo

* Fix comments

* Add assertion for use_feat_batchnorm
---
 .../ASR/conformer_ctc/conformer.py         | 28 ++++++++++++++-----
 egs/librispeech/ASR/conformer_ctc/train.py |  9 ++++--
 .../ASR/conformer_ctc/transformer.py       | 17 +++++++----
 3 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index b19b94db1..871712a46 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 from torch import Tensor, nn
@@ -56,7 +55,7 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -75,6 +74,9 @@ class Conformer(Transformer):

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)

+        use_conv_batchnorm = True
+        if isinstance(use_feat_batchnorm, float):
+            use_conv_batchnorm = False
         encoder_layer = ConformerEncoderLayer(
             d_model,
             nhead,
@@ -82,6 +84,7 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
+            use_conv_batchnorm,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
@@ -154,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
+        use_conv_batchnorm: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
@@ -174,7 +178,9 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )

-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
+        )

         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -843,12 +849,17 @@ class ConvolutionModule(nn.Module):
     """

     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        use_batchnorm: bool = False,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
+        self.use_batchnorm = use_batchnorm

         self.pointwise_conv1 = nn.Conv1d(
             channels,
@@ -867,7 +878,8 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        if self.use_batchnorm:
+            self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -897,7 +909,9 @@ class ConvolutionModule(nn.Module):

         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        if self.use_batchnorm:
+            x = self.norm(x)
+        x = self.activation(x)

         x = self.pointwise_conv2(x)  # (batch, channel, time)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 2fbf17a62..c1fa814c0 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -17,7 +17,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import argparse
 import logging
 from pathlib import Path
@@ -172,8 +171,12 @@ def get_params() -> AttributeDict:

         - subsampling_factor: The subsampling factor for the model.

-        - use_feat_batchnorm: Whether to do batch normalization for the
-                              input features.
+        - use_feat_batchnorm: Normalization for the input features, can be a
+                              boolean indicating whether to do batch
+                              normalization, or a float which means just scaling
+                              the input features with this float value.
+                              If given a float value, we will remove batchnorm
+                              layer in `ConvolutionModule` as well.

         - attention_dim: Hidden dim for multi-head attention model.

diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index f93914aaa..00ca027a7 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -14,9 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -42,7 +41,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         """
         Args:
@@ -72,10 +71,13 @@ class Transformer(nn.Module):
             True to use vgg style frontend for subsampling.
           use_feat_batchnorm:
             True to use batchnorm for the input layer.
+            Float value to scale the input layer.
+            False to do nothing.
         """
         super().__init__()
         self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
+        assert isinstance(use_feat_batchnorm, (float, bool))
+        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
             self.feat_batchnorm = nn.BatchNorm1d(num_features)

         self.num_features = num_features
@@ -179,10 +181,15 @@ class Transformer(nn.Module):
             memory_key_padding_mask for the decoder. Its shape is (N, T).
             It is None if `supervision` is None.
         """
-        if self.use_feat_batchnorm:
+        if (
+            isinstance(self.use_feat_batchnorm, bool)
+            and self.use_feat_batchnorm
+        ):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
+        if isinstance(self.use_feat_batchnorm, float):
+            x *= self.use_feat_batchnorm
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
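
Note on the new option: after this patch `use_feat_batchnorm` accepts either a bool (True keeps the old behaviour of applying BatchNorm1d to the input features) or a float (the input features are simply multiplied by that value, and the ConvolutionModule inside each ConformerEncoderLayer is then also built without its BatchNorm1d). The snippet below is a minimal standalone sketch of that dispatch logic, mirroring the new code in transformer.py; the wrapper class FeatNormOrScale and the demo shapes are illustrative only and not part of the patch.

from typing import Union

import torch
import torch.nn as nn


class FeatNormOrScale(nn.Module):
    """Hypothetical helper mirroring Transformer.forward() after this patch."""

    def __init__(
        self,
        num_features: int,
        use_feat_batchnorm: Union[float, bool] = 0.1,
    ) -> None:
        super().__init__()
        assert isinstance(use_feat_batchnorm, (float, bool))
        self.use_feat_batchnorm = use_feat_batchnorm
        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
            # bool True -> old behaviour: BatchNorm1d over the feature dim
            self.feat_batchnorm = nn.BatchNorm1d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (N, T, C)
        if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
            x = self.feat_batchnorm(x)
            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        if isinstance(self.use_feat_batchnorm, float):
            # float -> no batchnorm, just scale the input features
            x = x * self.use_feat_batchnorm
        return x


if __name__ == "__main__":
    feats = torch.randn(2, 100, 80)  # (N, T, C) fbank features
    print(FeatNormOrScale(80, use_feat_batchnorm=0.1)(feats).shape)   # scaling path
    print(FeatNormOrScale(80, use_feat_batchnorm=True)(feats).shape)  # batchnorm path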