Remove batch normalization

pkufool 2021-12-10 14:30:33 +08:00
parent 95af039733
commit db924dcef5
3 changed files with 37 additions and 14 deletions

View File

@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
@@ -56,7 +55,7 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -75,6 +74,9 @@ class Conformer(Transformer):
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+        use_conv_batchnorm = True
+        if isinstance(use_feat_batchnorm, float):
+            use_conv_batchnorm = False
         encoder_layer = ConformerEncoderLayer(
             d_model,
             nhead,
@@ -82,6 +84,7 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
+            use_conv_batchnorm,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
@@ -154,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
+        use_conv_batchnorm: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
@@ -174,7 +178,9 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
+        )
         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -843,12 +849,17 @@ class ConvolutionModule(nn.Module):
     """
     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        use_batchnorm: bool = False,
     ) -> None:
         """Construct a ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernel_size should be an odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
+        self.use_batchnorm = use_batchnorm
         self.pointwise_conv1 = nn.Conv1d(
             channels,
@@ -867,7 +878,8 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        if self.use_batchnorm:
+            self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -897,7 +909,9 @@ class ConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        if self.use_batchnorm:
+            x = self.norm(x)
+        x = self.activation(x)
         x = self.pointwise_conv2(x)  # (batch, channel, time)
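Taken together, the ConvolutionModule hunks make BatchNorm1d optional: the layer is only registered and only applied when use_batchnorm is True. Below is a minimal, self-contained sketch of that pattern, assuming a plain depthwise Conv1d and a SiLU activation; ConvBlock is an illustrative stand-in, not the icefall module itself.

import torch
from torch import nn


class ConvBlock(nn.Module):
    """Depthwise 1D conv block whose batchnorm can be switched off (sketch only)."""

    def __init__(
        self, channels: int, kernel_size: int, use_batchnorm: bool = False
    ) -> None:
        super().__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0
        self.use_batchnorm = use_batchnorm
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=channels,
        )
        # Only create the norm layer when it will actually be used.
        if self.use_batchnorm:
            self.norm = nn.BatchNorm1d(channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, channels, time)
        x = self.depthwise_conv(x)
        if self.use_batchnorm:
            x = self.norm(x)
        return self.activation(x)


x = torch.randn(4, 8, 100)
print(ConvBlock(8, 31, use_batchnorm=False)(x).shape)  # torch.Size([4, 8, 100])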

View File

@@ -17,7 +17,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
 import logging
 from pathlib import Path
@@ -172,8 +171,12 @@ def get_params() -> AttributeDict:
         - subsampling_factor: The subsampling factor for the model.
-        - use_feat_batchnorm: Whether to do batch normalization for the
-                              input features.
+        - use_feat_batchnorm: Normalization for the input features. It can be
+                              a bool indicating whether to apply batch
+                              normalization, or a float by which the input
+                              features are scaled. If a float is given, the
+                              batchnorm layer in `ConvolutionModule` is
+                              removed as well.
         - attention_dim: Hidden dim for multi-head attention model.
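Because use_feat_batchnorm is now Union[float, bool], callers dispatch on its runtime type. The hypothetical helper below (not part of icefall) just restates the three accepted settings from the docstring above; note that Python's bool is a subclass of int, not float, so True/False and 0.1 are cleanly separated by isinstance.

from typing import Union


def describe_feat_norm(use_feat_batchnorm: Union[float, bool]) -> str:
    """Summarize how the input features would be handled (illustration only)."""
    if isinstance(use_feat_batchnorm, bool):
        # True -> batchnorm over the input features; False -> no normalization.
        return "batchnorm" if use_feat_batchnorm else "none"
    # Any float -> batchnorm is removed and the features are simply scaled.
    return f"scale by {use_feat_batchnorm}"


print(describe_feat_norm(True))   # batchnorm
print(describe_feat_norm(False))  # none
print(describe_feat_norm(0.1))    # scale by 0.1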

View File

@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
@@ -42,7 +41,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         """
         Args:
@@ -72,10 +71,12 @@ class Transformer(nn.Module):
            True to use vgg style frontend for subsampling.
          use_feat_batchnorm:
            True to use batchnorm for the input layer.
+            Float value to scale the input features by instead of batchnorm.
+            False to do nothing.
         """
         super().__init__()
         self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
+        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
             self.feat_batchnorm = nn.BatchNorm1d(num_features)
         self.num_features = num_features
@@ -179,10 +180,15 @@ class Transformer(nn.Module):
             memory_key_padding_mask for the decoder. Its shape is (N, T).
             It is None if `supervision` is None.
         """
-        if self.use_feat_batchnorm:
+        if (
+            isinstance(self.use_feat_batchnorm, bool)
+            and self.use_feat_batchnorm
+        ):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
+        if isinstance(self.use_feat_batchnorm, float):
+            x *= self.use_feat_batchnorm
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
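Put together, the new forward path either batch-normalizes the (N, T, C) features or scales them by the given float before running the encoder. The sketch below isolates just that branch; FeatNorm is a hypothetical wrapper, not a class from icefall, and it assumes the float value itself is the scale factor.

from typing import Union

import torch
from torch import nn


class FeatNorm(nn.Module):
    """Input-feature handling as in the hunk above (sketch only)."""

    def __init__(
        self, num_features: int, use_feat_batchnorm: Union[float, bool] = 0.1
    ) -> None:
        super().__init__()
        self.use_feat_batchnorm = use_feat_batchnorm
        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
            self.feat_batchnorm = nn.BatchNorm1d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (N, T, C)
        if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T) for BatchNorm1d
            x = self.feat_batchnorm(x)
            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        elif isinstance(self.use_feat_batchnorm, float):
            x = x * self.use_feat_batchnorm  # plain scaling, no batchnorm
        return x


feats = torch.randn(2, 50, 80)  # (N, T, C)
print(FeatNorm(80, use_feat_batchnorm=0.1)(feats).shape)   # torch.Size([2, 50, 80])
print(FeatNorm(80, use_feat_batchnorm=True)(feats).shape)  # torch.Size([2, 50, 80])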