Remove batchnorm (#147)

* Remove batch normalization

* Minor fixes

* Fix typo

* Fix comments

* Add assertion for use_feat_batchnorm
Wei Kang 2021-12-14 08:20:03 +08:00 committed by GitHub
parent 95af039733
commit a183d5bfd7
3 changed files with 39 additions and 15 deletions


@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
@@ -56,7 +55,7 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -75,6 +74,9 @@ class Conformer(Transformer):
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+        use_conv_batchnorm = True
+        if isinstance(use_feat_batchnorm, float):
+            use_conv_batchnorm = False
         encoder_layer = ConformerEncoderLayer(
             d_model,
             nhead,
@@ -82,6 +84,7 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
+            use_conv_batchnorm,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
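A minimal sketch of the flag selection these two hunks add, assuming (as the transformer changes below enforce) that use_feat_batchnorm is always a bool or a float; the helper name is mine, in the diff the logic sits inline in the Conformer constructor:

from typing import Union


def conv_batchnorm_enabled(use_feat_batchnorm: Union[float, bool]) -> bool:
    # Mirrors the constructor logic above: a float means the input features
    # are only scaled, and BatchNorm is dropped from ConvolutionModule too;
    # any bool (True or False) keeps BatchNorm inside the conv module.
    return not isinstance(use_feat_batchnorm, float)


print(conv_batchnorm_enabled(0.1))   # False -> ConvolutionModule without BatchNorm
print(conv_batchnorm_enabled(True))  # True  -> ConvolutionModule with BatchNorm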
@@ -154,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
+        use_conv_batchnorm: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
@@ -174,7 +178,9 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
+        )
         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -843,12 +849,17 @@ class ConvolutionModule(nn.Module):
     """
     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        use_batchnorm: bool = False,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
+        self.use_batchnorm = use_batchnorm
         self.pointwise_conv1 = nn.Conv1d(
             channels,
@@ -867,7 +878,8 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        if self.use_batchnorm:
+            self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -897,7 +909,9 @@ class ConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        if self.use_batchnorm:
+            x = self.norm(x)
+        x = self.activation(x)
         x = self.pointwise_conv2(x)  # (batch, channel, time)
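The three hunks above make BatchNorm in ConvolutionModule optional: the layer is only created and only applied when use_batchnorm is True. A runnable, stripped-down sketch of that pattern (TinyConvBlock is a hypothetical stand-in; the pointwise convolutions and the module's own activation are omitted or replaced with ReLU):

import torch
import torch.nn as nn


class TinyConvBlock(nn.Module):
    """Simplified stand-in for the optional-BatchNorm pattern above."""

    def __init__(
        self, channels: int, kernel_size: int = 31, use_batchnorm: bool = False
    ) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0  # 'SAME' padding needs an odd kernel
        self.use_batchnorm = use_batchnorm
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=channels,
        )
        # The norm layer is only registered when it is actually wanted,
        # matching the `if self.use_batchnorm:` guard in the diff.
        if self.use_batchnorm:
            self.norm = nn.BatchNorm1d(channels)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channel, time)
        x = self.depthwise_conv(x)
        if self.use_batchnorm:
            x = self.norm(x)
        return self.activation(x)


x = torch.randn(2, 8, 50)  # (batch, channel, time)
print(TinyConvBlock(8, use_batchnorm=False)(x).shape)  # torch.Size([2, 8, 50])
print(TinyConvBlock(8, use_batchnorm=True)(x).shape)   # torch.Size([2, 8, 50])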


@@ -17,7 +17,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
 import logging
 from pathlib import Path
@@ -172,8 +171,12 @@ def get_params() -> AttributeDict:
         - subsampling_factor: The subsampling factor for the model.
-        - use_feat_batchnorm: Whether to do batch normalization for the
-                              input features.
+        - use_feat_batchnorm: Normalization for the input features, can be a
+                              boolean indicating whether to do batch
+                              normalization, or a float which means just scaling
+                              the input features with this float value.
+                              If given a float value, we will remove batchnorm
+                              layer in `ConvolutionModule` as well.
         - attention_dim: Hidden dim for multi-head attention model.
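Per the updated description, the parameter now takes either kind of value; a hypothetical illustration (plain dicts stand in for the AttributeDict that get_params() actually returns):

# Hypothetical settings for the updated parameter.
params_bn = {"use_feat_batchnorm": True}    # batch-normalize the input features
params_scale = {"use_feat_batchnorm": 0.1}  # scale input features by 0.1 and
                                            # drop BatchNorm in ConvolutionModule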


@@ -14,9 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
@@ -42,7 +41,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         """
         Args:
@@ -72,10 +71,13 @@ class Transformer(nn.Module):
             True to use vgg style frontend for subsampling.
           use_feat_batchnorm:
             True to use batchnorm for the input layer.
+            Float value to scale the input layer.
+            False to do nothing.
         """
         super().__init__()
         self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
+        assert isinstance(use_feat_batchnorm, (float, bool))
+        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
             self.feat_batchnorm = nn.BatchNorm1d(num_features)
         self.num_features = num_features
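The new assertion pins down which values are legal; a quick standalone check of what it accepts and rejects (illustration only, not part of the diff):

# bools and floats pass the new guard: True/False match the bool check
# and 0.1 matches the float check.
for value in (True, False, 0.1):
    assert isinstance(value, (float, bool))

# A plain int such as 1 is neither a bool nor a float, so it is rejected.
try:
    assert isinstance(1, (float, bool))
except AssertionError:
    print("plain int rejected")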
@@ -179,10 +181,15 @@ class Transformer(nn.Module):
             memory_key_padding_mask for the decoder. Its shape is (N, T).
             It is None if `supervision` is None.
         """
-        if self.use_feat_batchnorm:
+        if (
+            isinstance(self.use_feat_batchnorm, bool)
+            and self.use_feat_batchnorm
+        ):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
+        if isinstance(self.use_feat_batchnorm, float):
+            x *= self.use_feat_batchnorm
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
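A self-contained sketch of the feature handling this last hunk introduces: a bool applies BatchNorm1d over the feature dimension (hence the permutes), a float just scales the features. The standalone function normalize_feats is mine; in the diff the same logic runs inline before run_encoder is called:

from typing import Optional, Union

import torch
import torch.nn as nn


def normalize_feats(
    x: torch.Tensor,
    use_feat_batchnorm: Union[float, bool],
    feat_batchnorm: Optional[nn.BatchNorm1d] = None,
) -> torch.Tensor:
    # x has shape (N, T, C); BatchNorm1d wants the channel dim second,
    # hence the permutes around it.
    if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
        x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
        x = feat_batchnorm(x)
        x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
    if isinstance(use_feat_batchnorm, float):
        x = x * use_feat_batchnorm  # out-of-place here; the diff scales in place
    return x


x = torch.randn(4, 100, 80)  # (N, T, C) dummy features
print(normalize_feats(x, 0.1).shape)                       # scaling path
print(normalize_feats(x, True, nn.BatchNorm1d(80)).shape)  # batchnorm path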