Remove batchnorm (#147)
* Remove batch normalization
* Minor fixes
* Fix typo
* Fix comments
* Add assertion for use_feat_batchnorm
parent 95af039733
commit a183d5bfd7
@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor, nn
@@ -56,7 +55,7 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -75,6 +74,9 @@ class Conformer(Transformer):
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
+        use_conv_batchnorm = True
+        if isinstance(use_feat_batchnorm, float):
+            use_conv_batchnorm = False
         encoder_layer = ConformerEncoderLayer(
             d_model,
             nhead,
@@ -82,6 +84,7 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
+            use_conv_batchnorm,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
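The two hunks above derive a new use_conv_batchnorm flag from use_feat_batchnorm and pass it down to ConformerEncoderLayer. As a minimal Python sketch (not part of the diff, for illustration only), the decision rule reduces to: a float value disables BatchNorm inside the convolution module, while either boolean keeps it.

from typing import Union

def conv_batchnorm_enabled(use_feat_batchnorm: Union[float, bool]) -> bool:
    # Mirrors the diff: start from True, switch off only for a float value.
    return not isinstance(use_feat_batchnorm, float)

assert conv_batchnorm_enabled(True) is True
assert conv_batchnorm_enabled(False) is True  # bool False still keeps conv BatchNorm
assert conv_batchnorm_enabled(0.1) is False   # a float removes it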
@@ -154,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
+        use_conv_batchnorm: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
@@ -174,7 +178,9 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
+        )
 
         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -843,12 +849,17 @@ class ConvolutionModule(nn.Module):
     """
 
     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        use_batchnorm: bool = False,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
+        self.use_batchnorm = use_batchnorm
 
         self.pointwise_conv1 = nn.Conv1d(
             channels,
@@ -867,7 +878,8 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        if self.use_batchnorm:
+            self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -897,7 +909,9 @@ class ConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        if self.use_batchnorm:
+            x = self.norm(x)
+        x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
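The three ConvolutionModule hunks above make BatchNorm1d optional: the layer is only constructed and only applied when use_batchnorm=True. The following self-contained sketch (illustrative only; the kernel size, padding, and SiLU activation are assumptions, not icefall's actual module) shows the same conditional-norm pattern.

import torch
import torch.nn as nn

class TinyConvBlock(nn.Module):
    """Illustrative stand-in for the optional-BatchNorm pattern in the diff."""

    def __init__(self, channels: int, use_batchnorm: bool = False) -> None:
        super().__init__()
        self.use_batchnorm = use_batchnorm
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size=3, padding=1, groups=channels
        )
        if self.use_batchnorm:
            # Only allocate the norm layer when it will actually be applied.
            self.norm = nn.BatchNorm1d(channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time), matching the layout in the diff comments.
        x = self.depthwise_conv(x)
        if self.use_batchnorm:
            x = self.norm(x)
        return self.activation(x)

y = TinyConvBlock(channels=8, use_batchnorm=False)(torch.randn(2, 8, 16))

A side effect of this pattern is that when use_batchnorm=False the norm submodule is never registered, so it does not appear in the module's parameters or state dict at all.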
@@ -17,7 +17,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import argparse
 import logging
 from pathlib import Path
@@ -172,8 +171,12 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - use_feat_batchnorm: Whether to do batch normalization for the
-          input features.
+        - use_feat_batchnorm: Normalization for the input features, can be a
+                              boolean indicating whether to do batch
+                              normalization, or a float which means just scaling
+                              the input features with this float value.
+                              If given a float value, we will remove batchnorm
+                              layer in `ConvolutionModule` as well.
 
         - attention_dim: Hidden dim for multi-head attention model.
 
@@ -14,9 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -42,7 +41,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         """
         Args:
@@ -72,10 +71,13 @@ class Transformer(nn.Module):
           True to use vgg style frontend for subsampling.
         use_feat_batchnorm:
           True to use batchnorm for the input layer.
+          Float value to scale the input layer.
+          False to do nothing.
         """
         super().__init__()
         self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
+        assert isinstance(use_feat_batchnorm, (float, bool))
+        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
             self.feat_batchnorm = nn.BatchNorm1d(num_features)
 
         self.num_features = num_features
@@ -179,10 +181,15 @@ class Transformer(nn.Module):
             memory_key_padding_mask for the decoder. Its shape is (N, T).
             It is None if `supervision` is None.
         """
-        if self.use_feat_batchnorm:
+        if (
+            isinstance(self.use_feat_batchnorm, bool)
+            and self.use_feat_batchnorm
+        ):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
+        if isinstance(self.use_feat_batchnorm, float):
+            x *= self.use_feat_batchnorm
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
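Taken together, the Transformer changes above validate the flag in __init__ and branch on its type in forward: a boolean True applies BatchNorm1d to the input features, a float simply scales them. Below is a standalone sketch of that input-feature handling (not icefall's Transformer; the class name is made up and the (N, T, C) feature layout follows the permute comments in the diff).

import torch
import torch.nn as nn
from typing import Union

class FeatPreprocessor(nn.Module):
    """Sketch of the new use_feat_batchnorm handling: bool -> optional
    BatchNorm1d over features, float -> plain scaling of features."""

    def __init__(
        self, num_features: int, use_feat_batchnorm: Union[float, bool] = 0.1
    ) -> None:
        super().__init__()
        assert isinstance(use_feat_batchnorm, (float, bool))
        self.use_feat_batchnorm = use_feat_batchnorm
        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
            # The BatchNorm layer exists only when the flag is the boolean True.
            self.feat_batchnorm = nn.BatchNorm1d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, T, C) input features.
        if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
            x = self.feat_batchnorm(x)
            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        if isinstance(self.use_feat_batchnorm, float):
            x = x * self.use_feat_batchnorm  # scaling instead of BatchNorm
        return x

out = FeatPreprocessor(num_features=80, use_feat_batchnorm=0.1)(torch.randn(4, 10, 80))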