Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Make epsilon in BasicNorm learnable, optionally.
This commit is contained in:
parent 1962fe298b
commit fc873cc50d
@@ -56,7 +56,10 @@ class Conv2dSubsampling(nn.Module):
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
-        self.out_norm = BasicNorm(odim)
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain mean of output to be close to zero.
         self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6)
         self._reset_parameters()
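The comment added in this hunk argues that a learnable eps is redundant when BasicNorm is fed by a layer with its own learned scale. One way to see this, under the normalization formula quoted in the BasicNorm docstring in the next hunk, is that rescaling the input by a constant c has the same effect as dividing eps by c**2, so a learned input scale can already emulate any eps value. The snippet below is a purely illustrative check of that identity; the `basic_norm` helper is written for this example and is not the library code.

```python
import torch

def basic_norm(x, eps, dim=-1):
    # scale = ((x ** 2).mean() + eps) ** -0.5, as in the BasicNorm docstring below
    return x * (torch.mean(x ** 2, dim=dim, keepdim=True) + eps) ** -0.5

x = torch.randn(4, 16)
c = 3.0      # stand-in for the learned scale of the preceding ScaledLinear
eps = 0.25   # the default initial eps in this commit

# Scaling the input by c is equivalent to shrinking eps by c**2, so a learned
# input scale already covers what a learned eps would add here.
print(torch.allclose(basic_norm(c * x, eps),
                     basic_norm(x, eps / c ** 2), atol=1e-6))  # True
```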
@@ -361,42 +364,45 @@ class BasicNorm(torch.nn.Module):
 
     So the idea is to introduce this large constant value as an explicit
     parameter, that takes the role of the "eps" in LayerNorm, so the network
-    doesn't have to do this trick.
-
-    We also introduce a learned scaling factor on the output; and we
-    remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not
-    that useful unless the LayerNorm immediately follows a nonlinearity).
+    doesn't have to do this trick.  We make the "eps" learnable.
 
     Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
-      initial_eps: the initial "epsilon" that we add as ballast in:
+      eps: the initial "epsilon" that we add as ballast in:
          scale = ((input_vec**2).mean() + epsilon)**-0.5
-      Note: our epsilon is actually large, but we keep the name
-      to indicate the connection with normal LayerNorm.
-
-      speed: a scaling factor that can be interpreted as scaling the learning
-       rate for this module.  CAUTION: the default value of 10.0 is intended to be
-       used with Adam or amsgrad-type optimizers, e.g. Adam or Noam.
-       If you are using SGD you would probably have to set `speed` to
-       a value less than one, or the training would be unstable.
+      Note: our epsilon is actually large, but we keep the name
+      to indicate the connection with conventional LayerNorm.
+      learn_eps: if true, we learn epsilon; if false, we keep it
+       at the initial value.
+      eps_speed: a constant that determines how fast "eps" learns;
+        with Adam and variants, this should probably be >= 1,
+        e.g. 5.0.  For SGD and variants, a value less than one,
+        like 0.1, would probably be suitable, to prevent instability.
     """
     def __init__(self,
                  num_channels: int,
                  channel_dim: int = -1,  # CAUTION: see documentation.
-                 eps: float = 0.25):
+                 eps: float = 0.25,
+                 learn_eps: bool = True,
+                 eps_speed: float = 5.0):
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
-        self.eps = eps
+        self.eps_speed = eps_speed
+        if learn_eps:
+            self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
+        else:
+            self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5
+        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
+                  (self.eps * self.eps_speed).exp()) ** -0.5
         return x * scales
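To make the change above easier to read outside the diff, here is a minimal, self-contained sketch of the new behaviour. It restates the `__init__`/`forward` logic from this hunk; the class name `BasicNormSketch` and the example shapes are ours, not the library's. The key point is that eps is stored as log(eps) / eps_speed and multiplied back by eps_speed before exponentiation, which, as the docstring says, controls how fast eps moves under Adam-style optimizers.

```python
import torch
import torch.nn as nn
from torch import Tensor

class BasicNormSketch(nn.Module):
    """Illustrative restatement of the BasicNorm change in this commit."""
    def __init__(self, num_channels: int, channel_dim: int = -1,
                 eps: float = 0.25, learn_eps: bool = True,
                 eps_speed: float = 5.0):
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps_speed = eps_speed
        # store log(eps) / eps_speed; forward() multiplies by eps_speed and
        # exponentiates, so the stored value decodes back to a positive eps
        stored = (torch.tensor(eps).log() / eps_speed).detach()
        if learn_eps:
            self.eps = nn.Parameter(stored)
        else:
            self.register_buffer("eps", stored)

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) +
                  (self.eps * self.eps_speed).exp()) ** -0.5
        return x * scales

norm = BasicNormSketch(num_channels=8)
print(norm(torch.randn(2, 10, 8)).shape)  # torch.Size([2, 10, 8])
print((norm.eps * norm.eps_speed).exp())  # tensor(0.2500, ...): decodes to the initial eps
```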
@@ -1129,4 +1129,5 @@ if __name__ == '__main__':
     seq_len = 20
     # Just make sure the forward pass runs.
     f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64))
+          torch.full((batch_size,), seq_len, dtype=torch.int64),
+          warmup_mode=True)
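The hunk above only exercises a forward pass as a smoke test (now passing warmup_mode=True). An analogous quick check for the sketch in the previous block, again purely illustrative and reusing `BasicNormSketch` from above, is to confirm that eps receives a gradient when learn_eps=True and stays a fixed buffer when learn_eps=False:

```python
import torch

m = BasicNormSketch(num_channels=8, learn_eps=True)
m(torch.randn(3, 8)).sum().backward()
print(m.eps.grad is not None)                        # True: eps is being trained

m_fixed = BasicNormSketch(num_channels=8, learn_eps=False)
print(isinstance(m_fixed.eps, torch.nn.Parameter))   # False: registered as a buffer
print((m_fixed.eps * m_fixed.eps_speed).exp())       # tensor(0.2500): fixed at the initial value
```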
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved