diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 78fcac664..50a9db41a 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -60,8 +60,8 @@ class Conv2dSubsampling(nn.Module): # itself has learned scale, so the extra degree of freedom is not # needed. self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain mean of output to be close to zero. - self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) + # constrain median of output to be close to zero. + self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) self._reset_parameters() def _reset_parameters(self): @@ -536,7 +536,7 @@ class DerivBalancer(torch.nn.Module): """ def __init__(self, channel_dim: int, min_positive: float = 0.05, - max_positive: float = 1.0, + max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 8b229a234..cc1ae53a1 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -88,7 +88,7 @@ class Conformer(Transformer): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -179,6 +179,9 @@ class ConformerEncoderLayer(nn.Module): self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) + self.dropout = nn.Dropout(dropout) @@ -227,7 +230,7 @@ class ConformerEncoderLayer(nn.Module): # feed forward module src = src + self.dropout(self.feed_forward(src)) - src = self.norm_final(self.pre_norm_final(src)) + src = self.balancer(self.norm_final(self.pre_norm_final(src))) return src @@ -862,7 +865,8 @@ class ConvolutionModule(nn.Module): # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) + self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.05, max_positive=1.0) self.depthwise_conv = ScaledConv1d( channels, @@ -874,7 +878,8 @@ class ConvolutionModule(nn.Module): bias=bias, ) - self.deriv_balancer2 = DerivBalancer(channel_dim=1) + self.deriv_balancer2 = DerivBalancer(channel_dim=1, + min_positive=0.05, max_positive=1.0) # Shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 2af306f94..41fdb4ef3 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved