mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Replace 1st ConvNorm2d with BasicNorm, remove the 2nd.
This commit is contained in:
parent
a0b2276f68
commit
dd7257f01b
@ -450,8 +450,8 @@ class BasicNormFunction(torch.autograd.Function):
|
|||||||
bias = bias.unsqueeze(-1)
|
bias = bias.unsqueeze(-1)
|
||||||
scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
|
scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
|
||||||
ans = x * scales - bias
|
ans = x * scales - bias
|
||||||
ctx.save_for_backward(ans if store_output_for_backprop else x,
|
ctx.save_for_backward(ans.detach() if store_output_for_backprop else x.detach(),
|
||||||
scales, bias, eps)
|
scales.detach(), bias.detach(), eps.detach())
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -463,8 +463,6 @@ class BasicNormFunction(torch.autograd.Function):
|
|||||||
else:
|
else:
|
||||||
x = ans_or_x
|
x = ans_or_x
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
bias = bias.detach()
|
|
||||||
eps = eps.detach()
|
|
||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
bias.requires_grad = True
|
bias.requires_grad = True
|
||||||
eps.requires_grad = True
|
eps.requires_grad = True
|
||||||
|
|||||||
@ -1813,8 +1813,8 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
|
|
||||||
self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
|
self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
|
||||||
ConvNeXt(layer2_channels),
|
ConvNeXt(layer2_channels),
|
||||||
ConvNorm2d(layer2_channels,
|
BasicNorm(layer2_channels,
|
||||||
kernel_size=(15, 7))) # (time, freq)
|
channel_dim=1))
|
||||||
|
|
||||||
|
|
||||||
self.conv2 = nn.Sequential(
|
self.conv2 = nn.Sequential(
|
||||||
@ -1832,10 +1832,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
|
|
||||||
self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
|
self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
|
||||||
ConvNeXt(layer3_channels),
|
ConvNeXt(layer3_channels),
|
||||||
ConvNeXt(layer3_channels),
|
ConvNeXt(layer3_channels))
|
||||||
ConvNorm2d(layer3_channels,
|
|
||||||
kernel_size=(15, 5))) # (time, freq)
|
|
||||||
|
|
||||||
|
|
||||||
out_height = (((in_channels - 1) // 2) - 1) // 2
|
out_height = (((in_channels - 1) // 2) - 1) // 2
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user