mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Replace 1st ConvNorm2d with BasicNorm, remove the 2nd.
This commit is contained in:
parent
a0b2276f68
commit
dd7257f01b
@ -450,8 +450,8 @@ class BasicNormFunction(torch.autograd.Function):
|
||||
bias = bias.unsqueeze(-1)
|
||||
scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
|
||||
ans = x * scales - bias
|
||||
ctx.save_for_backward(ans if store_output_for_backprop else x,
|
||||
scales, bias, eps)
|
||||
ctx.save_for_backward(ans.detach() if store_output_for_backprop else x.detach(),
|
||||
scales.detach(), bias.detach(), eps.detach())
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
@ -463,8 +463,6 @@ class BasicNormFunction(torch.autograd.Function):
|
||||
else:
|
||||
x = ans_or_x
|
||||
x = x.detach()
|
||||
bias = bias.detach()
|
||||
eps = eps.detach()
|
||||
x.requires_grad = True
|
||||
bias.requires_grad = True
|
||||
eps.requires_grad = True
|
||||
|
||||
@ -1813,8 +1813,8 @@ class Conv2dSubsampling(nn.Module):
|
||||
|
||||
self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
|
||||
ConvNeXt(layer2_channels),
|
||||
ConvNorm2d(layer2_channels,
|
||||
kernel_size=(15, 7))) # (time, freq)
|
||||
BasicNorm(layer2_channels,
|
||||
channel_dim=1))
|
||||
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
@ -1832,10 +1832,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
|
||||
self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
|
||||
ConvNeXt(layer3_channels),
|
||||
ConvNeXt(layer3_channels),
|
||||
ConvNorm2d(layer3_channels,
|
||||
kernel_size=(15, 5))) # (time, freq)
|
||||
|
||||
ConvNeXt(layer3_channels))
|
||||
|
||||
out_height = (((in_channels - 1) // 2) - 1) // 2
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user