mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Draft of 0mean changes..
This commit is contained in:
parent
fc873cc50d
commit
261d7602a7
@ -60,8 +60,8 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
# itself has learned scale, so the extra degree of freedom is not
|
# itself has learned scale, so the extra degree of freedom is not
|
||||||
# needed.
|
# needed.
|
||||||
self.out_norm = BasicNorm(odim, learn_eps=False)
|
self.out_norm = BasicNorm(odim, learn_eps=False)
|
||||||
# constrain mean of output to be close to zero.
|
# constrain median of output to be close to zero.
|
||||||
self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6)
|
self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
|
||||||
self._reset_parameters()
|
self._reset_parameters()
|
||||||
|
|
||||||
def _reset_parameters(self):
|
def _reset_parameters(self):
|
||||||
@ -536,7 +536,7 @@ class DerivBalancer(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, channel_dim: int,
|
def __init__(self, channel_dim: int,
|
||||||
min_positive: float = 0.05,
|
min_positive: float = 0.05,
|
||||||
max_positive: float = 1.0,
|
max_positive: float = 0.95,
|
||||||
max_factor: float = 0.01,
|
max_factor: float = 0.01,
|
||||||
min_abs: float = 0.2,
|
min_abs: float = 0.2,
|
||||||
max_abs: float = 100.0):
|
max_abs: float = 100.0):
|
||||||
|
@ -88,7 +88,7 @@ class Conformer(Transformer):
|
|||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
|
self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -179,6 +179,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.pre_norm_final = Identity()
|
self.pre_norm_final = Identity()
|
||||||
self.norm_final = BasicNorm(d_model)
|
self.norm_final = BasicNorm(d_model)
|
||||||
|
|
||||||
|
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||||
|
self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55)
|
||||||
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
|
||||||
@ -227,7 +230,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
# feed forward module
|
# feed forward module
|
||||||
src = src + self.dropout(self.feed_forward(src))
|
src = src + self.dropout(self.feed_forward(src))
|
||||||
|
|
||||||
src = self.norm_final(self.pre_norm_final(src))
|
src = self.balancer(self.norm_final(self.pre_norm_final(src)))
|
||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
@ -862,7 +865,8 @@ class ConvolutionModule(nn.Module):
|
|||||||
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
||||||
# it will be in a better position to start learning something, i.e. to latch onto
|
# it will be in a better position to start learning something, i.e. to latch onto
|
||||||
# the correct range.
|
# the correct range.
|
||||||
self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0)
|
self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0,
|
||||||
|
min_positive=0.05, max_positive=1.0)
|
||||||
|
|
||||||
self.depthwise_conv = ScaledConv1d(
|
self.depthwise_conv = ScaledConv1d(
|
||||||
channels,
|
channels,
|
||||||
@ -874,7 +878,8 @@ class ConvolutionModule(nn.Module):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.deriv_balancer2 = DerivBalancer(channel_dim=1)
|
self.deriv_balancer2 = DerivBalancer(channel_dim=1,
|
||||||
|
min_positive=0.05, max_positive=1.0)
|
||||||
|
|
||||||
# Shape: (channels, 1), broadcasts with (batch, channel, time).
|
# Shape: (channels, 1), broadcasts with (batch, channel, time).
|
||||||
self.activation = SwishOffset()
|
self.activation = SwishOffset()
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale",
|
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
Loading…
x
Reference in New Issue
Block a user