Apply limit on BasicNorm.eps more effectively using limit_param_value; add final norm to Zipformer.

Daniel Povey 2022-12-23 15:59:51 +08:00
parent 049174722f
commit 2e0f4de8ff
2 changed files with 5 additions and 9 deletions

@@ -481,16 +481,10 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
-        eps = self.eps
-        if self.training and random.random() < 0.25:
-            # with probability 0.25, in training mode, clamp eps between the min
-            # and max; this will encourage it to learn parameters within the
-            # allowed range by making parameters that are outside the allowed
-            # range noisy.
-            # gradients to allow the parameter to get back into the allowed
-            # region if it happens to exit it.
-            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+        eps = self.eps
+        if self.training:
+            eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max)
         eps = eps.exp()
         scales = (
             (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) /
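The hunk above replaces the occasional hard clamp on eps with a call to limit_param_value, whose definition is not part of this commit. The sketch below is a hypothetical stand-in that illustrates the usual idea behind such a helper: leave the value untouched in the forward pass, but flip any gradient component that would push an already out-of-range element further outside [min, max], so the parameter is steered back into range instead of having its gradient cut off by a clamp. The class name _LimitValue is made up for illustration, and the real helper in icefall's scaling.py may differ (for example, it may apply the constraint only with some probability).

import torch
from torch import Tensor


class _LimitValue(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass, flips any gradient
    # component that would push an element already outside [min_val, max_val]
    # even further out, so the optimizer steers it back into range.
    @staticmethod
    def forward(ctx, x: Tensor, min_val: float, max_val: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.min_val = min_val
        ctx.max_val = max_val
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        (x,) = ctx.saved_tensors
        # With SGD, a positive gradient decreases x.  Flip the gradient where
        # the update would move an out-of-range element further out of range.
        flip = ((x < ctx.min_val) & (grad > 0)) | ((x > ctx.max_val) & (grad < 0))
        grad = torch.where(flip, -grad, grad)
        return grad, None, None


def limit_param_value(x: Tensor, min: float, max: float) -> Tensor:
    # Hypothetical stand-in matching the call site in the diff above.
    return _LimitValue.apply(x, min, max)

Under this sketch, the call limit_param_value(self.eps, min=self.eps_min, max=self.eps_max) costs nothing in the forward pass and only intervenes when the log-space eps drifts outside its allowed range.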


@@ -216,6 +216,7 @@ class Zipformer(EncoderInterface):
             encoder_dim[-1],
             downsample=output_downsampling_factor,
             dropout=dropout)
+        self.norm = BasicNorm(num_channels=encoder_dim[-1])

     def _init_skip_modules(self):
@@ -357,6 +358,7 @@ class Zipformer(EncoderInterface):
         lengths = (lengths + 1) // 2

         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
+        x = self.norm(x)
         return x, lengths
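The second file gains a final self.norm = BasicNorm(...) applied to the encoder output just before it is returned. For context, the sketch below is a minimal BasicNorm-like layer consistent with the forward() shown in the first hunk: no learned scale or bias, just division by sqrt(mean(x**2) + eps) over the channel dimension, with eps stored as a learned log-space parameter (hence the .exp()). It is an illustrative reimplementation, not the icefall class; the eps_min/eps_max limits and the real constructor defaults are omitted.

import torch
from torch import Tensor, nn


class BasicNormSketch(nn.Module):
    # Minimal BasicNorm-style layer: RMS normalization with no learned
    # scale/bias; eps is a learned parameter kept in log space so that it
    # stays positive after .exp().
    def __init__(self, num_channels: int, channel_dim: int = -1, eps: float = 0.25):
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.eps = nn.Parameter(torch.tensor(eps).log())

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp()
        ) ** -0.5
        return x * scales

Applied as x = self.norm(x) on the (N, T, C) output, a layer of this kind normalizes each frame's encoder_dim[-1]-dimensional feature vector to roughly unit RMS before the encoder returns it.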