Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 17:42:21 +00:00
Remove batchnorm (#147)

* Remove batch normalization
* Minor fixes
* Fix typo
* Fix comments
* Add assertion for use_feat_batchnorm

parent 95af039733
commit a183d5bfd7
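This commit turns `use_feat_batchnorm` from a plain boolean into a `Union[float, bool]` (default `0.1`): `True` keeps the `BatchNorm1d` over the input features, while a float skips batchnorm and merely scales the input features by that value, and it also removes the `BatchNorm1d` inside `ConvolutionModule`. A minimal standalone sketch of the resulting dispatch (illustrative only, not the repository code; the class name is made up):

```python
from typing import Union

import torch
from torch import Tensor, nn


class FeatNormSketch(nn.Module):
    """Hypothetical mini-module mirroring the new use_feat_batchnorm semantics."""

    def __init__(
        self, num_features: int, use_feat_batchnorm: Union[float, bool] = 0.1
    ) -> None:
        super().__init__()
        assert isinstance(use_feat_batchnorm, (float, bool))
        self.use_feat_batchnorm = use_feat_batchnorm
        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
            # Only a boolean True creates the input-feature batchnorm.
            self.feat_batchnorm = nn.BatchNorm1d(num_features)

    def forward(self, x: Tensor) -> Tensor:
        # x: (N, T, C)
        if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm:
            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
            x = self.feat_batchnorm(x)
            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        if isinstance(self.use_feat_batchnorm, float):
            x = x * self.use_feat_batchnorm  # scale-only "normalization"
        return x


# Either form is accepted:
y1 = FeatNormSketch(80, use_feat_batchnorm=0.1)(torch.randn(2, 5, 80))
y2 = FeatNormSketch(80, use_feat_batchnorm=True)(torch.randn(2, 5, 80))
```

The reconstructed hunks follow.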
@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor, nn
@@ -56,7 +55,7 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -75,6 +74,9 @@ class Conformer(Transformer):
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
+        use_conv_batchnorm = True
+        if isinstance(use_feat_batchnorm, float):
+            use_conv_batchnorm = False
         encoder_layer = ConformerEncoderLayer(
             d_model,
             nhead,
@@ -82,6 +84,7 @@ class Conformer(Transformer):
             dropout,
             cnn_module_kernel,
             normalize_before,
+            use_conv_batchnorm,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
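The two hunks above thread the float-vs-bool distinction down into the encoder layers: `Conformer.__init__` derives a `use_conv_batchnorm` flag and forwards it to every `ConformerEncoderLayer`, so choosing a float for `use_feat_batchnorm` also drops batchnorm from the convolution modules. The derivation is equivalent to this one-liner (shown only for clarity):

```python
use_feat_batchnorm = 0.1  # or True / False
# float -> use_conv_batchnorm=False : no BatchNorm1d in ConvolutionModule
# bool  -> use_conv_batchnorm=True  : BatchNorm1d kept (the pre-commit behaviour)
use_conv_batchnorm = not isinstance(use_feat_batchnorm, float)
```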
@@ -154,6 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
+        use_conv_batchnorm: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.self_attn = RelPositionMultiheadAttention(
@@ -174,7 +178,9 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
+        )
 
         self.norm_ff_macaron = nn.LayerNorm(
             d_model
@@ -843,12 +849,17 @@ class ConvolutionModule(nn.Module):
     """
 
     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        use_batchnorm: bool = False,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0
+        self.use_batchnorm = use_batchnorm
 
         self.pointwise_conv1 = nn.Conv1d(
             channels,
@@ -867,7 +878,8 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        if self.use_batchnorm:
+            self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -897,7 +909,9 @@ class ConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        if self.use_batchnorm:
+            x = self.norm(x)
+        x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
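With the three `ConvolutionModule` hunks above, `BatchNorm1d` is only registered and applied when `use_batchnorm=True`; otherwise the activation follows the depthwise convolution directly. A condensed, runnable sketch of that pattern (a simplified stand-in, not the full repository class; the Swish activation is approximated with `nn.SiLU`):

```python
import torch
from torch import Tensor, nn


class TinyConvModule(nn.Module):
    """Simplified stand-in for ConvolutionModule with optional batchnorm."""

    def __init__(
        self, channels: int, kernel_size: int, use_batchnorm: bool = False
    ) -> None:
        super().__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0
        self.use_batchnorm = use_batchnorm
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=channels,
        )
        if self.use_batchnorm:
            self.norm = nn.BatchNorm1d(channels)
        self.activation = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch, channel, time)
        x = self.depthwise_conv(x)
        if self.use_batchnorm:
            x = self.norm(x)
        return self.activation(x)


# With use_batchnorm=False (the value derived from a float use_feat_batchnorm),
# no BatchNorm1d parameters or running stats are created at all.
m = TinyConvModule(channels=4, kernel_size=3, use_batchnorm=False)
y = m(torch.randn(2, 4, 10))  # -> (2, 4, 10)
```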
@@ -17,7 +17,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import argparse
 import logging
 from pathlib import Path
@@ -172,8 +171,12 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - use_feat_batchnorm: Whether to do batch normalization for the
-                              input features.
+        - use_feat_batchnorm: Normalization for the input features, can be a
+                              boolean indicating whether to do batch
+                              normalization, or a float which means just scaling
+                              the input features with this float value.
+                              If given a float value, we will remove batchnorm
+                              layer in `ConvolutionModule` as well.
 
         - attention_dim: Hidden dim for multi-head attention model.
 
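Following the updated docstring, the same `use_feat_batchnorm` entry in the training parameters can now carry either form; for example (values here are purely illustrative, not the recipe's defaults):

```python
from typing import Dict, Union

# Illustrative parameter dict: one key, two accepted forms.
params: Dict[str, Union[float, bool]] = {
    # float: just scale the input features by 0.1
    # (and drop batchnorm from ConvolutionModule as well)
    "use_feat_batchnorm": 0.1,
}
# Boolean True would instead keep BatchNorm1d on the input features:
# params["use_feat_batchnorm"] = True
```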
@@ -14,9 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -42,7 +41,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
+        use_feat_batchnorm: Union[float, bool] = 0.1,
     ) -> None:
         """
         Args:
@@ -72,10 +71,13 @@ class Transformer(nn.Module):
             True to use vgg style frontend for subsampling.
           use_feat_batchnorm:
             True to use batchnorm for the input layer.
+            Float value to scale the input layer.
+            False to do nothing.
         """
         super().__init__()
         self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
+        assert isinstance(use_feat_batchnorm, (float, bool))
+        if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
             self.feat_batchnorm = nn.BatchNorm1d(num_features)
 
         self.num_features = num_features
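The `Transformer.__init__` hunk above adds the assertion mentioned in the commit message, so an unsupported value fails fast; only a boolean `True` still registers the `feat_batchnorm` layer. For instance:

```python
# Both accepted forms pass the new assertion:
assert isinstance(0.1, (float, bool))    # scale-only mode
assert isinstance(True, (float, bool))   # input-feature batchnorm mode
# Anything else, e.g. a string read from a config file, fails immediately:
# assert isinstance("0.1", (float, bool))  # AssertionError
```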
@@ -179,10 +181,15 @@ class Transformer(nn.Module):
             memory_key_padding_mask for the decoder. Its shape is (N, T).
             It is None if `supervision` is None.
         """
-        if self.use_feat_batchnorm:
+        if (
+            isinstance(self.use_feat_batchnorm, bool)
+            and self.use_feat_batchnorm
+        ):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
+        if isinstance(self.use_feat_batchnorm, float):
+            x *= self.use_feat_batchnorm
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
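A small point worth noting about the forward-pass hunk above: because `bool` is a subclass of `int` but not of `float` in Python, the two `isinstance` branches are mutually exclusive, so `True` never falls into the scaling branch and a float never triggers batchnorm:

```python
print(isinstance(True, float))  # False: True only matches the batchnorm branch
print(isinstance(0.1, bool))    # False: a float only matches the scaling branch
```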