Replace 1st ConvNorm2d with BasicNorm, remove the 2nd.

Daniel Povey 2022-12-22 16:50:52 +08:00
parent a0b2276f68
commit dd7257f01b
2 changed files with 5 additions and 10 deletions

@@ -450,8 +450,8 @@ class BasicNormFunction(torch.autograd.Function):
         bias = bias.unsqueeze(-1)
         scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
         ans = x * scales - bias
-        ctx.save_for_backward(ans if store_output_for_backprop else x,
-                              scales, bias, eps)
+        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x.detach(),
+                              scales.detach(), bias.detach(), eps.detach())
         return ans
 
     @staticmethod
@@ -463,8 +463,6 @@ class BasicNormFunction(torch.autograd.Function):
         else:
             x = ans_or_x
         x = x.detach()
-        bias = bias.detach()
-        eps = eps.detach()
         x.requires_grad = True
         bias.requires_grad = True
         eps.requires_grad = True
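
For context, this Function uses the recompute-in-backward pattern: forward saves (detached) tensors, and backward re-runs the forward math under autograd to obtain gradients. Moving the .detach() calls into forward, as in the first hunk, makes the saved tensors graph-free leaves from the start, which is why the explicit detaches in backward become redundant and are removed in the second hunk. Below is a minimal self-contained sketch of the pattern, not the icefall code: the name NormRecomputeSketch is made up, the channel dimension is fixed at 1, and the store_output_for_backprop branch is omitted.

import torch

class NormRecomputeSketch(torch.autograd.Function):
    # Sketch of the pattern used by BasicNormFunction: save detached
    # tensors in forward, then recompute the forward math in backward
    # under torch.enable_grad() to get gradients.

    @staticmethod
    def forward(ctx, x, bias, eps):
        bias_ = bias.unsqueeze(-1)
        scales = (torch.mean((x + bias_) ** 2, dim=1, keepdim=True)
                  + eps.exp()) ** -0.5
        ans = x * scales - bias_
        # Detach before saving (as in the hunk above): the saved tensors
        # are then leaf tensors, so backward may flip requires_grad on them.
        ctx.save_for_backward(x.detach(), bias.detach(), eps.detach())
        return ans

    @staticmethod
    def backward(ctx, ans_grad):
        x, bias, eps = ctx.saved_tensors
        x.requires_grad = True
        bias.requires_grad = True
        eps.requires_grad = True
        with torch.enable_grad():
            # Re-run the forward computation, this time recording a graph.
            bias_ = bias.unsqueeze(-1)
            scales = (torch.mean((x + bias_) ** 2, dim=1, keepdim=True)
                      + eps.exp()) ** -0.5
            ans = x * scales - bias_
            ans.backward(gradient=ans_grad)
        return x.grad, bias.grad, eps.grad

x = torch.randn(2, 4, 7, requires_grad=True)    # (batch, channels, time)
bias = torch.zeros(4, requires_grad=True)
eps = torch.tensor(-2.0, requires_grad=True)    # eps is stored as log(eps)
NormRecomputeSketch.apply(x, bias, eps).sum().backward()

The trade-off is the usual recomputation one: nothing from the forward graph is kept alive between forward and backward, at the cost of evaluating the normalization twice.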

@@ -1813,8 +1813,8 @@ class Conv2dSubsampling(nn.Module):
         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
                                        ConvNeXt(layer2_channels),
-                                       ConvNorm2d(layer2_channels,
-                                                  kernel_size=(15, 7)))  # (time, freq)
+                                       BasicNorm(layer2_channels,
+                                                 channel_dim=1))
 
         self.conv2 = nn.Sequential(
@@ -1832,10 +1832,7 @@ class Conv2dSubsampling(nn.Module):
         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels),
-                                       ConvNorm2d(layer3_channels,
-                                                  kernel_size=(15, 5)))  # (time, freq)
+                                       ConvNeXt(layer3_channels))
 
         out_height = (((in_channels - 1) // 2) - 1) // 2
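
For orientation: where ConvNorm2d presumably gathered its normalization statistics over a (time, freq) window given by kernel_size, BasicNorm with channel_dim=1 normalizes each (time, freq) position independently using only its channel vector, matching the math in the BasicNormFunction hunk above. A minimal stand-in follows; BasicNorm is icefall's module, but this sketch only mirrors the formula from the first hunk, and the 0.25 initial epsilon is an assumption.

import torch
import torch.nn as nn

class BasicNormSketch(nn.Module):
    # Stand-in mirroring the math of BasicNormFunction above:
    # scale x by the inverse RMS of (x + bias) over channel_dim,
    # with a learned log-epsilon for numerical stability.
    def __init__(self, num_channels: int, channel_dim: int = -1):
        super().__init__()
        self.channel_dim = channel_dim
        self.bias = nn.Parameter(torch.zeros(num_channels))
        # Learned in log space, matching the eps.exp() in the hunk;
        # the 0.25 initial value is an assumption.
        self.eps = nn.Parameter(torch.tensor(0.25).log())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = [1] * x.ndim
        shape[self.channel_dim] = -1          # broadcast bias over channel_dim
        bias = self.bias.view(shape)
        scales = (torch.mean((x + bias) ** 2, dim=self.channel_dim,
                             keepdim=True) + self.eps.exp()) ** -0.5
        return x * scales - bias

x = torch.randn(2, 8, 50, 19)             # (batch, channels, time, freq)
norm = BasicNormSketch(8, channel_dim=1)  # normalize over channels only
assert norm(x).shape == x.shape

Note that the second ConvNorm2d, the kernel_size=(15, 5) one at the end of convnext2, is dropped without replacement, so after this commit the second ConvNeXt stack has no trailing normalization at all.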