Make epsilon in BasicNorm learnable, optionally.

Daniel Povey 2022-03-15 17:00:17 +08:00
parent 1962fe298b
commit fc873cc50d
3 changed files with 28 additions and 21 deletions

View File

@@ -56,7 +56,10 @@ class Conv2dSubsampling(nn.Module):
             DoubleSwish(),
         )
         self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
-        self.out_norm = BasicNorm(odim)
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(odim, learn_eps=False)
         # constrain mean of output to be close to zero.
         self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6)
         self._reset_parameters()
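The new comment above argues that learn_eps=False suffices here because the preceding `out` layer (a ScaledLinear) already has a learned scale. A small numerical sketch of that reasoning, written by me and not part of the commit: scaling BasicNorm's input by a factor c is exactly equivalent to dividing eps by c**2, so a learned scale in front of the norm already provides the degree of freedom a learned eps would add.

import torch

def basic_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    # the BasicNorm formula with a plain float eps
    scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
    return x * scales

x = torch.randn(4, 80)
eps, c = 0.25, 3.0
y1 = basic_norm(c * x, eps)        # scale the input, keep eps fixed
y2 = basic_norm(x, eps / c ** 2)   # keep the input, shrink eps instead
print(torch.allclose(y1, y2, atol=1e-6))  # prints: True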
@@ -361,42 +364,45 @@ class BasicNorm(torch.nn.Module):
     So the idea is to introduce this large constant value as an explicit
     parameter, that takes the role of the "eps" in LayerNorm, so the network
-    doesn't have to do this trick.
-    We also introduce a learned scaling factor on the output; and we
-    remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not
-    that useful unless the LayerNorm immediately follows a nonlinearity).
+    doesn't have to do this trick.  We make the "eps" learnable.
 
     Args:
+       num_channels: the number of channels, e.g. 512.
        channel_dim: the axis/dimension corresponding to the channel,
          interprted as an offset from the input's ndim if negative.
          shis is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
-       initial_eps: the initial "epsilon" that we add as ballast in:
+       eps: the initial "epsilon" that we add as ballast in:
           scale = ((input_vec**2).mean() + epsilon)**-0.5
           Note: our epsilon is actually large, but we keep the name
-          to indicate the connection with normal LayerNorm.
-       speed: a scaling factor that can be interpreted as scaling the learning
-          rate for this module.  CAUTION: the default value of 10.0 intended to be
-          used with Adam or amsgrad-type optimizers, e.g. Adam or Noam.
-          If you are using SGD you would probably have to set `speed` to
-          a value less than one, or the training would be unstable.
+          to indicate the connection with conventional LayerNorm.
+       learn_eps: if true, we learn epsilon; if false, we keep it
+         at the initial value.
+       eps_speed: a constant that determines how fast "eps" learns;
+          with Adam and variants, this should probably be >= 1,
+          e.g. 5.0.  For SGD and variants, probably a value less than one,
+          like 0.1, would be suitable, to prevent instability.
     """
     def __init__(self,
                  num_channels: int,
                  channel_dim: int = -1,  # CAUTION: see documentation.
-                 eps: float = 0.25):
+                 eps: float = 0.25,
+                 learn_eps: bool = True,
+                 eps_speed: float = 5.0):
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
-        self.eps = eps
+        self.eps_speed = eps_speed
+        if learn_eps:
+            self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
+        else:
+            self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())
 
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5
+        scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) +
+                  (self.eps * self.eps_speed).exp()) ** -0.5
         return x * scales
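For reference, here is the updated BasicNorm assembled from the hunk above into a self-contained sketch; the imports and the usage lines at the bottom are mine, not part of the commit. The stored value is log(eps) / eps_speed, and the forward pass recovers eps with (self.eps * self.eps_speed).exp(), so with Adam-type optimizers eps_speed roughly multiplies how fast log(eps) moves per step, which matches the docstring's advice to use values >= 1 for Adam and smaller values for SGD.

import torch
import torch.nn as nn
from torch import Tensor

class BasicNorm(nn.Module):
    def __init__(self,
                 num_channels: int,
                 channel_dim: int = -1,  # CAUTION: see documentation.
                 eps: float = 0.25,
                 learn_eps: bool = True,
                 eps_speed: float = 5.0):
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps_speed = eps_speed
        # eps is stored as log(eps) / eps_speed; forward() undoes this below.
        if learn_eps:
            self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach())
        else:
            self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach())

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) +
                  (self.eps * self.eps_speed).exp()) ** -0.5
        return x * scales

# Usage: a learnable-eps norm over the last dimension, plus the fixed-eps
# variant that Conv2dSubsampling uses in the first hunk.
norm = BasicNorm(num_channels=512)
fixed = BasicNorm(num_channels=512, learn_eps=False)
y = norm(torch.randn(10, 20, 512))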

View File

@@ -1129,4 +1129,5 @@ if __name__ == '__main__':
     seq_len = 20
     # Just make sure the forward pass runs.
     f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64))
+          torch.full((batch_size,), seq_len, dtype=torch.int64),
+          warmup_mode=True)

View File

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved