Make epsilon in BasicNorm learnable, optionally.

Daniel Povey 2022-03-15 17:00:17 +08:00
parent 1962fe298b
commit fc873cc50d
3 changed files with 28 additions and 21 deletions

View File

@@ -56,7 +56,10 @@ class Conv2dSubsampling(nn.Module):
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
-        self.out_norm = BasicNorm(odim)
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain mean of output to be close to zero.
         self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6)
         self._reset_parameters()
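The new comment argues that, because `out` (a ScaledLinear) already has a learned scale, a learnable eps on the following norm would duplicate a degree of freedom. A small sketch, not part of the commit, that checks this intuition numerically: `basic_norm` below is a standalone re-implementation of the normalization formula from the BasicNorm hunk further down, and `s` merely stands in for the learned output scale of the preceding ScaledLinear. Scaling the norm's input by s is exactly equivalent to dividing eps by s**2, so one learned knob can emulate the other.

import torch

def basic_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    # RMS-style normalization as in BasicNorm: no mean subtraction.
    scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
    return x * scales

x = torch.randn(4, 512)
s = 2.5  # stand-in for the learned output scale of the preceding ScaledLinear
lhs = basic_norm(s * x, eps=0.25)           # scale absorbed by the input
rhs = basic_norm(x, eps=0.25 / s ** 2)      # same effect, expressed through eps
print(torch.allclose(lhs, rhs, atol=1e-6))  # True: the two knobs are redundant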
@@ -361,42 +364,45 @@ class BasicNorm(torch.nn.Module):
     So the idea is to introduce this large constant value as an explicit
     parameter, that takes the role of the "eps" in LayerNorm, so the network
-    doesn't have to do this trick.
-
-    We also introduce a learned scaling factor on the output; and we
-    remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not
-    that useful unless the LayerNorm immediately follows a nonlinearity).
+    doesn't have to do this trick.  We make the "eps" learnable.
 
     Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
        This is NOT the num_channels; it should typically be one of
        {-2, -1, 0, 1, 2, 3}.
-      initial_eps: the initial "epsilon" that we add as ballast in:
+      eps: the initial "epsilon" that we add as ballast in:
           scale = ((input_vec**2).mean() + epsilon)**-0.5
-        Note: our epsilon is actually large, but we keep the name
-        to indicate the connection with normal LayerNorm.
-      speed: a scaling factor that can be interpreted as scaling the learning
-        rate for this module. CAUTION: the default value of 10.0 intended to be
-        used with Adam or amsgrad-type optimizers, e.g. Adam or Noam.
-        If you are using SGD you would probably have to set `speed` to
-        a value less than one, or the training would be unstable.
+        Note: our epsilon is actually large, but we keep the name
+        to indicate the connection with conventional LayerNorm.
+      learn_eps: if true, we learn epsilon; if false, we keep it
+        at the initial value.
+      eps_speed: a constant that determines how fast "eps" learns;
+        with Adam and variants, this should probably be >= 1,
+        e.g. 5.0.  For SGD and variants, probably a value less than one,
+        like 0.1, would be suitable, to prevent instability.
     """
 
     def __init__(self,
                  num_channels: int,
                  channel_dim: int = -1,  # CAUTION: see documentation.
-                 eps: float = 0.25):
+                 eps: float = 0.25,
+                 learn_eps: bool = True,
+                 eps_speed: float = 5.0):
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
-        self.eps = eps
+        self.eps_speed = eps_speed
+        if learn_eps:
+            self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
+        else:
+            self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5
+        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
+                  (self.eps * self.eps_speed).exp()) ** -0.5
         return x * scales
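A minimal usage sketch of the new interface, assuming the BasicNorm class exactly as in the hunk above is available (with torch, nn and Tensor imported in its module); the 512-channel size and the (batch, time, channels) shape are only illustrative.

import torch

norm = BasicNorm(num_channels=512, eps=0.25, learn_eps=True, eps_speed=5.0)

# What is stored is log(eps) / eps_speed; the forward pass recovers the
# actual ballast value via (self.eps * self.eps_speed).exp().
stored = norm.eps.item()                       # log(0.25) / 5.0, about -0.277
recovered = (norm.eps * norm.eps_speed).exp()  # about 0.25, the initial eps

x = torch.randn(10, 20, 512)   # (batch, time, channels); channel_dim = -1
y = norm(x)                    # same shape, rescaled towards unit RMS
print(stored, recovered.item(), y.shape)

Because the parameter stores log(eps) divided by eps_speed, a fixed-size optimizer step on it (as Adam-type optimizers tend to take) moves log(eps) by roughly eps_speed times as much, which is presumably why the docstring recommends eps_speed >= 1 for Adam-family optimizers and a smaller value for SGD.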

View File

@@ -1129,4 +1129,5 @@ if __name__ == '__main__':
     seq_len = 20
     # Just make sure the forward pass runs.
     f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64))
+          torch.full((batch_size,), seq_len, dtype=torch.int64),
+          warmup_mode=True)

View File

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved