Remove batchnorm (#147)

* Remove batch normalization

* Minor fixes

* Fix typo

* Fix comments

* Add assertion for use_feat_batchnorm
Wei Kang 2021-12-14 08:20:03 +08:00 committed by GitHub
parent 95af039733
commit a183d5bfd7
3 changed files with 39 additions and 15 deletions

View File

@@ -15,10 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
@@ -56,7 +55,7 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
use_feat_batchnorm: Union[float, bool] = 0.1,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@@ -75,6 +74,9 @@ class Conformer(Transformer):
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
use_conv_batchnorm = True
if isinstance(use_feat_batchnorm, float):
use_conv_batchnorm = False
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
@@ -82,6 +84,7 @@
dropout,
cnn_module_kernel,
normalize_before,
use_conv_batchnorm,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
@@ -154,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
use_conv_batchnorm: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
@@ -174,7 +178,9 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
)
self.norm_ff_macaron = nn.LayerNorm(
d_model
@@ -843,12 +849,17 @@ class ConvolutionModule(nn.Module):
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
use_batchnorm: bool = False,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.use_batchnorm = use_batchnorm
self.pointwise_conv1 = nn.Conv1d(
channels,
@@ -867,7 +878,8 @@ class ConvolutionModule(nn.Module):
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
if self.use_batchnorm:
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
@@ -897,7 +909,9 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
if self.use_batchnorm:
x = self.norm(x)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)
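To see the use_batchnorm toggle in isolation, here is a hypothetical, stripped-down stand-in for the module (ToyConvModule is not icefall's ConvolutionModule: it keeps only the depthwise convolution, the optional BatchNorm1d, and an activation, omits the pointwise convolutions and gating, and assumes a Swish/SiLU activation):

import torch
import torch.nn as nn


class ToyConvModule(nn.Module):
    """Illustrates only the optional-batchnorm path added in this diff."""

    def __init__(
        self, channels: int, kernel_size: int, use_batchnorm: bool = False
    ) -> None:
        super().__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0
        self.use_batchnorm = use_batchnorm
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=channels,
        )
        if self.use_batchnorm:
            self.norm = nn.BatchNorm1d(channels)
        self.activation = nn.SiLU()  # assumed activation, not shown in the diff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channel, time)
        x = self.depthwise_conv(x)
        if self.use_batchnorm:
            x = self.norm(x)
        return self.activation(x)


m = ToyConvModule(channels=8, kernel_size=31, use_batchnorm=False)
print(m(torch.randn(2, 8, 100)).shape)  # torch.Size([2, 8, 100])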

View File

@@ -17,7 +17,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
@@ -172,8 +171,12 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- use_feat_batchnorm: Normalization for the input features. It can be a
boolean, indicating whether to apply batch
normalization, or a float, in which case the input
features are simply scaled by that value.
If a float is given, the batchnorm layer in
`ConvolutionModule` is removed as well.
- attention_dim: Hidden dim for multi-head attention model.
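For concreteness, the two accepted settings described above look like this (the values here are only illustrative; just the name use_feat_batchnorm comes from the diff):

use_feat_batchnorm = 0.1     # scale the input features by 0.1; the batchnorm in ConvolutionModule is dropped too
# use_feat_batchnorm = True  # apply BatchNorm1d to the input features instead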

View File

@@ -14,9 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -42,7 +41,7 @@ class Transformer(nn.Module):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
use_feat_batchnorm: Union[float, bool] = 0.1,
) -> None:
"""
Args:
@@ -72,10 +71,13 @@
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
Float value to scale the input layer.
False to do nothing.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
assert isinstance(use_feat_batchnorm, (float, bool))
if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
@@ -179,10 +181,15 @@ class Transformer(nn.Module):
memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
if (
isinstance(self.use_feat_batchnorm, bool)
and self.use_feat_batchnorm
):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
if isinstance(self.use_feat_batchnorm, float):
x *= self.use_feat_batchnorm
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
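Taken together, the constructor and forward changes amount to the following input-feature handling. A self-contained sketch (FeatNorm is a hypothetical wrapper, not the Transformer class itself; only the dispatch logic and the (N, T, C) layout follow the diff):

import torch
import torch.nn as nn
from typing import Union


class FeatNorm(nn.Module):
    def __init__(
        self, num_features: int, use_feat_batchnorm: Union[float, bool] = 0.1
    ) -> None:
        super().__init__()
        assert isinstance(use_feat_batchnorm, (float, bool))
        self.use_feat_batchnorm = use_feat_batchnorm
        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
            self.feat_batchnorm = nn.BatchNorm1d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T) for BatchNorm1d
            x = self.feat_batchnorm(x)
            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        if isinstance(self.use_feat_batchnorm, float):
            x = x * self.use_feat_batchnorm  # just scale the input features
        return x


x = torch.randn(4, 50, 80)  # (N, T, C)
print(FeatNorm(80, use_feat_batchnorm=0.1)(x).shape)   # torch.Size([4, 50, 80])
print(FeatNorm(80, use_feat_batchnorm=True)(x).shape)  # torch.Size([4, 50, 80])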