Bug fix; remove BasicNorm; add one more ConvNeXt layer.

commit 96daf7a00f
parent 744dca1c9b
Author: Daniel Povey
Date:   2022-12-17 16:11:54 +08:00

@@ -40,6 +40,7 @@ from scaling import (
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
     softmax,
+    caching_eval,
     ScheduledFloat,
     FloatLike,
     limit_param_value,
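
The only change in this hunk is the new `caching_eval` import; its definition lives in scaling.py and is not part of this diff. Purely as a hypothetical illustration of the pattern the name suggests (an assumption, not the actual icefall implementation), a caching eval helper might memoize a module's output while in eval mode:

    import torch
    from torch import nn, Tensor

    def caching_eval(module: nn.Module, x: Tensor, cache: dict) -> Tensor:
        # HYPOTHETICAL sketch only: the real scaling.caching_eval is not
        # shown in this diff and may do something entirely different.
        # In eval mode, reuse a cached output for a previously seen input;
        # in training mode, always recompute so gradients stay correct.
        if module.training:
            return module(x)
        key = (id(module), id(x))
        if key not in cache:
            with torch.no_grad():
                cache[key] = module(x)
        return cache[key]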
@@ -1699,8 +1700,8 @@ class ConvNeXt(nn.Module):
             min_positive=0.3,
             max_positive=1.0,
             min_abs=0.75,
-            max_abs=5.0,
-            min_prob=0.25)
+            max_abs=5.0)
         self.activation = SwooshL()
         self.pointwise_conv2 = ScaledConv2d(
             in_channels=hidden_channels,
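
This hunk only trims the tail of a balancer-style constructor call inside ConvNeXt.__init__: `min_prob=0.25` is dropped, so the module falls back to the constructor's default probability. A sketch of the resulting call — the attribute name, the callee name `Balancer`, and the leading arguments sit above the diff window and are assumptions; the keyword arguments come from the hunk:

    # Sketch of the call after this commit.  `self.balancer` and the first
    # two arguments are assumed (not visible in the hunk); the remaining
    # keyword arguments are exactly those shown in the diff.
    self.balancer = Balancer(
        hidden_channels,        # assumed: width of the hidden pointwise conv
        channel_dim=1,          # assumed: (N, C, H, W) layout, channels at dim 1
        min_positive=0.3,
        max_positive=1.0,
        min_abs=0.75,
        max_abs=5.0,            # `min_prob=0.25` removed by this commit
    )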
@@ -1795,9 +1796,7 @@ class Conv2dSubsampling(nn.Module):
         self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
                                        ConvNeXt(layer2_channels),
-                                       ConvNeXt(layer2_channels),
-                                       BasicNorm(layer2_channels,
-                                                 channel_dim=1))
+                                       ConvNeXt(layer2_channels))

         self.conv2 = nn.Sequential(
             nn.Conv2d(
@@ -1809,8 +1808,7 @@ class Conv2dSubsampling(nn.Module):
         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
-                                       BasicNorm(layer3_channels,
-                                                 channel_dim=1))
+                                       ConvNeXt(layer3_channels))

         out_height = (((in_channels - 1) // 2) - 1) // 2
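
Net effect of the two Conv2dSubsampling hunks: the trailing BasicNorm is dropped from both stacks, and convnext2 grows from two ConvNeXt blocks to three — the "one more ConvNeXt layer" of the commit message. A sketch of just these two attributes after the commit (ConvNeXt and the layer*_channels values come from this file; the rest of __init__ is omitted):

    # Before:  convnext1 = ConvNeXt x3 -> BasicNorm;  convnext2 = ConvNeXt x2 -> BasicNorm
    # After:   convnext1 = ConvNeXt x3;               convnext2 = ConvNeXt x3
    self.convnext1 = nn.Sequential(ConvNeXt(layer2_channels),
                                   ConvNeXt(layer2_channels),
                                   ConvNeXt(layer2_channels))
    self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
                                   ConvNeXt(layer3_channels),
                                   ConvNeXt(layer3_channels))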