Regularize how we apply the min and max to the eps of BasicNorm

Daniel Povey 2022-10-26 12:51:20 +08:00
parent a0507a83a5
commit bf37c7ca85
3 changed files with 22 additions and 2 deletions


@@ -360,7 +360,11 @@ class ConformerEncoderLayer(nn.Module):
         delta = src - src_orig
         bypass_scale = self.bypass_scale
-        if random.random() > 0.1:
+        if self.training and random.random() < 0.25:
+            # With probability 0.25, in training mode, clamp bypass_scale to
+            # [0.1, 1.0]; this encourages it to learn values within this range
+            # by making values outside that range noisy.
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale
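
A minimal, hypothetical illustration (not code from this repository) of the trick used above: with probability 0.25 during training, the scale is replaced by a clamped copy, so a value that has drifted outside [0.1, 1.0] is seen inconsistently from one iteration to the next, while on the remaining iterations the raw parameter is used as-is.

    import random
    import torch

    bypass_scale = torch.tensor(1.5)  # pretend the learned scale drifted above 1.0
    effective = []
    for _ in range(8):
        s = bypass_scale
        if random.random() < 0.25:  # same probability as in the diff above
            s = s.clamp(min=0.1, max=1.0)
        effective.append(float(s))
    print(effective)  # a noisy mixture of 1.5 and 1.0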


@ -159,7 +159,7 @@ class ScaledAdam(BatchedOptimizer):
eps=1.0e-08, eps=1.0e-08,
param_min_rms=1.0e-05, param_min_rms=1.0e-05,
param_max_rms=3.0, param_max_rms=3.0,
scalar_max=5.0, scalar_max=10.0,
size_update_period=4, size_update_period=4,
clipping_update_period=100, clipping_update_period=100,
): ):
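
The hunk above only changes the default value of scalar_max. A caller that wants the old behaviour can pass the value explicitly; the sketch below assumes ScaledAdam is constructed like a standard torch optimizer (parameters first), and the import path, model, and lr are placeholders, not taken from the diff.

    import torch
    from optim import ScaledAdam  # assumed import path

    model = torch.nn.Linear(4, 4)  # placeholder model for illustration
    optimizer = ScaledAdam(
        model.parameters(),  # assumes the usual torch-optimizer call pattern
        lr=0.05,             # illustrative value only
        scalar_max=5.0,      # restores the previous default changed above
    )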


@@ -348,6 +348,8 @@ class BasicNorm(torch.nn.Module):
           to indicate the connection with conventional LayerNorm.
        learn_eps: if true, we learn epsilon; if false, we keep it
           at the initial value.
+       eps_min: float
+       eps_max: float
     """

     def __init__(
@@ -356,6 +358,8 @@ class BasicNorm(torch.nn.Module):
         channel_dim: int = -1,  # CAUTION: see documentation.
         eps: float = 0.25,
         learn_eps: bool = True,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -364,9 +368,21 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.eps_min = eps_min
+        self.eps_max = eps_max

     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
+        eps = self.eps
+        if self.training and random.random() < 0.25:
+            # With probability 0.25, in training mode, clamp eps between the
+            # min and max; this encourages it to learn values within the
+            # allowed range by making values outside that range noisy.  On
+            # the other iterations the unclamped value is used, so the
+            # parameter still gets gradients that allow it to move back into
+            # the allowed region if it happens to exit it.
+            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
         scales = (
             torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
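
For reference, a self-contained sketch of a BasicNorm-like module with the stochastic eps clamp added above. It is an assumption-laden toy, not the repository's class: the normalization formula (x * (mean(x**2) + exp(eps)) ** -0.5) is inferred from context because the hunk is truncated after the eps term, and channel_dim is fixed to the last dimension for brevity.

    import random
    import torch
    from torch import Tensor, nn

    class BasicNormSketch(nn.Module):
        def __init__(
            self,
            num_channels: int,
            eps: float = 0.25,
            eps_min: float = -3.0,
            eps_max: float = 3.0,
        ) -> None:
            super().__init__()
            self.num_channels = num_channels
            # eps is stored in log space so that exp(eps) is always positive.
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
            self.eps_min = eps_min
            self.eps_max = eps_max

        def forward(self, x: Tensor) -> Tensor:
            assert x.shape[-1] == self.num_channels
            eps = self.eps
            if self.training and random.random() < 0.25:
                # Occasionally train against the clamped value so that a
                # log-eps outside [eps_min, eps_max] becomes noisy.
                eps = eps.clamp(min=self.eps_min, max=self.eps_max)
            scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
            return x * scales

    # Usage: normalize a (batch, time, channels) tensor.
    x = torch.randn(2, 5, 8)
    y = BasicNormSketch(num_channels=8)(x)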