mirror of https://github.com/k2-fsa/icefall.git

commit 3d6ee443e3 (parent 43f2a8d50b)

    Revert some recent changes that may not have been helpful.
@@ -466,8 +466,8 @@ class BasicNorm(torch.nn.Module):
         channel_dim: int = -1,  # CAUTION: see documentation.
         eps: float = 0.25,
         learn_eps: bool = True,
-        eps_min: float = -2.0,
-        eps_max: float = 2.0,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
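Note (not part of the diff): in BasicNorm, eps is stored in log space and clamped to [eps_min, eps_max] by limit_param_value() before being exponentiated, so widening the clamp from +/-2.0 back to +/-3.0 widens the range the effective eps can reach. A minimal sketch of the mapping:

import math

# The learned log-eps is clamped to [lo, hi], then exponentiated in forward().
for lo, hi in [(-2.0, 2.0), (-3.0, 3.0)]:
    print(f"log-eps in [{lo}, {hi}] -> eps in [{math.exp(lo):.4f}, {math.exp(hi):.4f}]")
# log-eps in [-2.0, 2.0] -> eps in [0.1353, 7.3891]
# log-eps in [-3.0, 3.0] -> eps in [0.0498, 20.0855]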
@@ -487,8 +487,8 @@ class BasicNorm(torch.nn.Module):
         eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max)
         eps = eps.exp()
         scales = (
-            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) /
-            (1.0 + eps)
+            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps)
+            # / (1.0 + eps)
         ) ** -0.5
         return x * scales
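Note (not part of the diff): for a given eps, dropping the (1.0 + eps) denominator only rescales the output by a constant factor, since ((m + eps) / (1 + eps)) ** -0.5 == (m + eps) ** -0.5 * (1 + eps) ** 0.5. A small self-contained check, assuming channel_dim=-1:

import torch

x = torch.randn(2, 10, 384)          # (batch, time, channels)
eps = torch.tensor(0.25)

mean_sq = torch.mean(x ** 2, dim=-1, keepdim=True)
scales_old = ((mean_sq + eps) / (1.0 + eps)) ** -0.5  # pre-revert form
scales_new = (mean_sq + eps) ** -0.5                  # reverted form
# The two differ only by the constant factor (1 + eps) ** 0.5:
assert torch.allclose(scales_old, scales_new * (1.0 + eps) ** 0.5)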
@@ -573,8 +573,9 @@ class ZipformerEncoderLayer(nn.Module):
         src = src + self.feed_forward2(src)
 
         src = self.balancer(src)
+        src = self.norm_final(src)
 
-        delta = self.norm_final(src - src_orig)
+        delta = src - src_orig
 
         src = src_orig + delta * self.get_bypass_scale(src.shape[1])
         src = self.whiten(src)
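Note (not part of the diff): the revert moves norm_final from the residual branch onto the full activation, so the bypass now interpolates between src_orig and the normalized src instead of adding a normalized delta. A toy illustration, with nn.LayerNorm and a fixed scalar standing in for norm_final and get_bypass_scale():

import torch
import torch.nn as nn

norm_final = nn.LayerNorm(384)            # stand-in for the layer's norm_final
bypass_scale = torch.tensor(0.5)          # stand-in for get_bypass_scale(...)

src_orig = torch.randn(10, 2, 384)        # (seq_len, batch, channels)
src = src_orig + torch.randn(10, 2, 384)  # activation after the sub-modules

# Before the revert: normalize only the residual delta.
out_before = src_orig + norm_final(src - src_orig) * bypass_scale

# After the revert: normalize the activation, then take the raw delta.
src_n = norm_final(src)
out_after = src_orig + (src_n - src_orig) * bypass_scale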
@@ -1820,14 +1821,13 @@ class Conv2dSubsampling(nn.Module):
         )
 
         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
-                                       ConvNeXt(layer2_channels))
+                                       ConvNeXt(layer2_channels),
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
 
 
         cur_width = (in_channels - 1) // 2
 
-        self.norm1 = BasicNorm(layer2_channels * cur_width,
-                               channel_dim=-1)
 
 
         self.conv2 = nn.Sequential(
             nn.Conv2d(
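Note (not part of the diff): BasicNorm(layer2_channels, channel_dim=1) normalizes over the conv channel axis of a (batch, channels, frames, width) tensor, one scale per (frame, width) position, whereas the removed self.norm1 flattened channels and width together and normalized each frame jointly. A sketch of the two statistics, using the fixed-eps form of BasicNorm for simplicity:

import torch

x = torch.randn(8, 128, 50, 40)  # (batch, layer2_channels, frames, width)
eps = 0.25

# New placement: stats over dim=1 (conv channels), per (frame, width) position.
x_new = x * (torch.mean(x ** 2, dim=1, keepdim=True) + eps) ** -0.5

# Old placement: flatten to (batch, frames, channels * width), stats over the
# last dim, i.e. jointly over channels and width for each frame.
b, c, t, w = x.shape
y = x.permute(0, 2, 1, 3).reshape(b, t, c * w)
y = y * (torch.mean(y ** 2, dim=-1, keepdim=True) + eps) ** -0.5
x_old = y.reshape(b, t, c, w).permute(0, 2, 1, 3)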
@@ -1883,10 +1883,6 @@ class Conv2dSubsampling(nn.Module):
         x = self.conv1(x)
         x = self.convnext1(x)
 
-        (batch_size, layer2_channels, num_frames, cur_width) = x.shape
-        x = x.permute(0, 2, 1, 3).reshape(batch_size, num_frames, layer2_channels * cur_width)
-        x = self.norm1(x)
-        x = x.reshape(batch_size, num_frames, layer2_channels, cur_width).permute(0, 2, 1, 3)
 
         x = self.conv2(x)
         x = self.convnext2(x)
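Note (not part of the diff): with the norm folded into self.convnext1, the tensor stays in (batch, channels, frames, width) layout throughout the forward pass and the reshape round-trip disappears. A schematic of the resulting flow:

def forward_tail(self, x):
    # x: (batch, channels, frames, width) throughout; normalization now
    # happens inside convnext1 via BasicNorm(channel_dim=1).
    x = self.conv1(x)
    x = self.convnext1(x)
    x = self.conv2(x)
    x = self.convnext2(x)
    return x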