Make epsilon in BasicNorm learnable, optionally.
This commit is contained in:
parent 1962fe298b
commit fc873cc50d
@@ -56,7 +56,10 @@ class Conv2dSubsampling(nn.Module):
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
-        self.out_norm = BasicNorm(odim)
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain mean of output to be close to zero.
         self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6)
         self._reset_parameters()
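A quick numerical check of the reasoning in the comment above: because BasicNorm rescales by the root-mean-square of its input, a learned scale applied just before it does not change the overall output magnitude, it only changes the effective epsilon, so a separately learned eps would duplicate a degree of freedom that the learned scale of `out` already provides. The snippet below is a standalone sketch; basic_norm is a hypothetical functional version of BasicNorm's forward pass, not repository code.

import torch

# Standalone sketch (not repository code): a functional version of
# BasicNorm's forward pass, normalizing over the last dimension.
def basic_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
    return x * scales

x = torch.randn(3, 8)
c = 2.5      # stands in for the learned scale of the preceding ScaledLinear
eps = 0.25

# Scaling the input by c is equivalent to shrinking eps by c**2; the output
# is otherwise identical, so a learnable eps adds no new expressive power here.
a = basic_norm(c * x, eps)
b = basic_norm(x, eps / c ** 2)
print(torch.allclose(a, b, atol=1e-5))  # True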
@@ -361,42 +364,45 @@ class BasicNorm(torch.nn.Module):
 
     So the idea is to introduce this large constant value as an explicit
     parameter, that takes the role of the "eps" in LayerNorm, so the network
-    doesn't have to do this trick.
+    doesn't have to do this trick.  We make the "eps" learnable.
 
-    We also introduce a learned scaling factor on the output; and we
-    remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not
-    that useful unless the LayerNorm immediately follows a nonlinearity).
-
     Args:
+       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
-      initial_eps: the initial "epsilon" that we add as ballast in:
+      eps: the initial "epsilon" that we add as ballast in:
         scale = ((input_vec**2).mean() + epsilon)**-0.5
       Note: our epsilon is actually large, but we keep the name
-      to indicate the connection with normal LayerNorm.
-      speed: a scaling factor that can be interpreted as scaling the learning
-        rate for this module. CAUTION: the default value of 10.0 intended to be
-        used with Adam or amsgrad-type optimizers, e.g. Adam or Noam.
-        If you are using SGD you would probably have to set `speed` to
-        a value less than one, or the training would be unstable.
+      to indicate the connection with conventional LayerNorm.
+      learn_eps: if true, we learn epsilon; if false, we keep it
+        at the initial value.
+      eps_speed: a constant that determines how fast "eps" learns;
+        with Adam and variants, this should probably be >= 1,
+        e.g. 5.0. For SGD and variants, probably a value less than one,
+        like 0.1, would be suitable, to prevent instability.
     """
     def __init__(self,
                  num_channels: int,
                  channel_dim: int = -1,  # CAUTION: see documentation.
-                 eps: float = 0.25):
+                 eps: float = 0.25,
+                 learn_eps: bool = True,
+                 eps_speed: float = 5.0):
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
-        self.eps = eps
+        self.eps_speed = eps_speed
+        if learn_eps:
+            self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
+        else:
+            self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())
 
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5
+        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
+                  (self.eps * self.eps_speed).exp()) ** -0.5
         return x * scales
 
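Putting the pieces of the hunk above together, here is a hedged, self-contained sketch of the new parameterization (the class name BasicNormSketch and the __main__ check are illustrative, not repository code): eps is stored as log(eps) / eps_speed, either as an nn.Parameter (learn_eps=True) or as a buffer, and forward() recovers the effective epsilon as exp(stored * eps_speed), which keeps it positive and lets eps_speed scale how fast it moves under the optimizer.

import torch
import torch.nn as nn
from torch import Tensor


class BasicNormSketch(nn.Module):
    """Illustrative re-implementation of the changed BasicNorm.

    eps is stored as log(eps) / eps_speed; forward() recovers the effective
    epsilon as exp(stored * eps_speed), which keeps it positive and lets
    eps_speed act like a step-size multiplier when eps is learnable.
    """

    def __init__(self, num_channels: int, channel_dim: int = -1,
                 eps: float = 0.25, learn_eps: bool = True,
                 eps_speed: float = 5.0) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps_speed = eps_speed
        stored = (torch.tensor(eps).log() / eps_speed).detach()
        if learn_eps:
            self.eps = nn.Parameter(stored)        # learnable
        else:
            self.register_buffer("eps", stored)    # fixed at the initial value

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
                  + (self.eps * self.eps_speed).exp()) ** -0.5
        return x * scales


if __name__ == "__main__":
    m = BasicNormSketch(num_channels=4)
    # At init the effective epsilon equals the requested value (0.25),
    # regardless of eps_speed.
    print(float((m.eps * m.eps_speed).exp()))
    y = m(torch.randn(2, 3, 4))
    print(y.shape)  # torch.Size([2, 3, 4])

At initialization the effective epsilon is unchanged by eps_speed; eps_speed only rescales how quickly the stored value moves during training, which is why the docstring above suggests values >= 1 for Adam-type optimizers and smaller values such as 0.1 for SGD.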
@@ -1129,4 +1129,5 @@ if __name__ == '__main__':
     seq_len = 20
     # Just make sure the forward pass runs.
     f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64))
+          torch.full((batch_size,), seq_len, dtype=torch.int64),
+          warmup_mode=True)
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
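Finally, a minimal argparse sketch (illustrative only, not the recipe's full get_parser()) showing how the changed --exp-dir default behaves: the new directory name is used when the flag is omitted and is overridden when the flag is given explicitly.

import argparse

# Illustrative stand-in for the parser fragment touched above; only the
# --exp-dir option is shown, with the new default from this commit.
parser = argparse.ArgumentParser(description="toy stand-in for get_parser()")
parser.add_argument(
    "--exp-dir",
    type=str,
    default=(
        "transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_"
        "maxp0.95_noexp_convderiv3warmup_embed_scale"
    ),
    help="The experiment dir where checkpoints, logs, etc. are saved.",
)

args = parser.parse_args([])  # empty argv -> the new default is used
print(args.exp_dir)
args = parser.parse_args(["--exp-dir", "transducer_stateless/my_exp"])
print(args.exp_dir)           # an explicit value overrides the default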