Revert some recent changes that may not have been helpful.

Daniel Povey 2022-12-24 21:17:43 +08:00
parent 43f2a8d50b
commit 3d6ee443e3
2 changed files with 9 additions and 13 deletions

@@ -466,8 +466,8 @@ class BasicNorm(torch.nn.Module):
         channel_dim: int = -1,  # CAUTION: see documentation.
         eps: float = 0.25,
         learn_eps: bool = True,
-        eps_min: float = -2.0,
-        eps_max: float = 2.0,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -487,8 +487,8 @@ class BasicNorm(torch.nn.Module):
         eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max)
         eps = eps.exp()
         scales = (
-            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) /
-            (1.0 + eps)
+            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps)
+            # / (1.0 + eps)
         ) ** -0.5
         return x * scales
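
Taken together, the two hunks above widen the clamp range on the log-space eps (from [-2, 2] to [-3, 3], i.e. roughly e^-3 to e^3 after the exp) and drop the division by (1.0 + eps), so the scale is now just an inverse RMS with an eps floor. A minimal standalone sketch of the new scale computation; the function name is mine, and a plain clamp stands in for limit_param_value, which is defined elsewhere in the same file:

import torch

def basic_norm_scales(x: torch.Tensor,
                      log_eps: torch.Tensor,
                      channel_dim: int = -1,
                      eps_min: float = -3.0,
                      eps_max: float = 3.0) -> torch.Tensor:
    # Clamp the learned log-space eps into [eps_min, eps_max] and exponentiate,
    # mirroring limit_param_value(...) followed by eps.exp() in the diff
    # (clamp is a stand-in for limit_param_value here).
    eps = log_eps.clamp(min=eps_min, max=eps_max).exp()
    # New behaviour: inverse RMS over channel_dim with an eps floor;
    # the old "/ (1.0 + eps)" renormalization is now commented out.
    return (torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5

For example, x * basic_norm_scales(x, torch.tensor(0.25).log()) corresponds to the default eps of 0.25 shown in the constructor.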

@@ -573,8 +573,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward2(src)
         src = self.balancer(src)
-        delta = self.norm_final(src - src_orig)
+        src = self.norm_final(src)
+        delta = src - src_orig
         src = src_orig + delta * self.get_bypass_scale(src.shape[1])
         src = self.whiten(src)
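
The hunk above moves norm_final so that it normalizes the layer output itself, and the bypass delta becomes the raw difference against the layer input. A minimal sketch of the resulting tail of the forward pass; the helper name is mine, and the norm module and bypass scale are passed in as arguments (in the layer they come from self.norm_final and self.get_bypass_scale(src.shape[1])):

import torch
from typing import Callable

def final_norm_and_bypass(src: torch.Tensor,
                          src_orig: torch.Tensor,
                          norm_final: Callable[[torch.Tensor], torch.Tensor],
                          bypass_scale: torch.Tensor) -> torch.Tensor:
    # New order: normalize the full layer output first ...
    src = norm_final(src)
    # ... then take the plain residual against the layer input
    # (previously norm_final was applied to this difference instead).
    delta = src - src_orig
    # Interpolate back towards the input with the learned bypass scale.
    return src_orig + delta * bypass_scale
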
@@ -1820,14 +1821,13 @@ class Conv2dSubsampling(nn.Module):
         )
         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
-                                       ConvNeXt(layer2_channels))
+                                       ConvNeXt(layer2_channels),
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
         cur_width = (in_channels - 1) // 2
-        self.norm1 = BasicNorm(layer2_channels * cur_width,
-                               channel_dim=-1)
         self.conv2 = nn.Sequential(
             nn.Conv2d(
@@ -1883,10 +1883,6 @@ class Conv2dSubsampling(nn.Module):
         x = self.conv1(x)
         x = self.convnext1(x)
-        (batch_size, layer2_channels, num_frames, cur_width) = x.shape
-        x = x.permute(0, 2, 1, 3).reshape(batch_size, num_frames, layer2_channels * cur_width)
-        x = self.norm1(x)
-        x = x.reshape(batch_size, num_frames, layer2_channels, cur_width).permute(0, 2, 1, 3)
         x = self.conv2(x)
         x = self.convnext2(x)
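
The last two hunks replace the flattened-feature normalization with a BasicNorm applied over the channel dimension inside convnext1, so the permute/reshape round trip in forward() disappears. A rough sketch of what the appended BasicNorm(layer2_channels, channel_dim=1) computes on the 4-D feature map, with a fixed eps standing in for the module's learned value and a helper name of my own:

import torch

def norm_over_channels(x: torch.Tensor, eps: float = 0.25) -> torch.Tensor:
    # x: (batch, layer2_channels, num_frames, cur_width), the output of the
    # two ConvNeXt blocks in convnext1.
    # Inverse-RMS scale over dim=1 (the channel dimension), applied in place
    # of the old self.norm1, which normalized over the flattened
    # layer2_channels * cur_width features after a permute/reshape.
    scales = (torch.mean(x ** 2, dim=1, keepdim=True) + eps) ** -0.5
    return x * scales

Because the scale is now taken over channels at each (frame, width) position, the batch_size/num_frames reshaping that the deleted forward() code needed is no longer required.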