diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 2da9c6445..0811406e3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1506,23 +1506,20 @@ class ConvolutionModule(nn.Module):
     """

     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self, channels: int, kernel_size: int,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0

-        self.pointwise_conv1 = nn.Conv1d(
-            channels,
-            2 * channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=bias,
+        self.in_proj = LinearWithAuxLoss(
+            channels, 2 * channels,
+            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
         )

-        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+
+        # after in_proj we put x through a gated linear unit (nn.functional.glu).
         # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
         # but sometimes, for some reason, for layer 0 the rms ends up being very large,
         # between 50 and 100 for different channels.  This will cause very peaky and
@@ -1536,8 +1533,8 @@ class ConvolutionModule(nn.Module):
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
-            2 * channels,
-            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+            2 * channels, channel_dim=-1,
+            max_abs=10.0, min_positive=0.05, max_positive=1.0
         )

         self.depthwise_conv = nn.Conv1d(
@@ -1547,7 +1544,7 @@
             stride=1,
             padding=(kernel_size - 1) // 2,
             groups=channels,
-            bias=bias,
+            bias=True,
         )

         self.deriv_balancer2 = ActivationBalancer(
@@ -1563,13 +1560,9 @@
                                prob=(0.025, 0.25),
                                grad_scale=0.01)

-        self.pointwise_conv2 = ScaledConv1d(
-            channels,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=bias,
+        self.out_proj = LinearWithAuxLoss(
+            channels, channels,
+            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)),
             initial_scale=0.05,
         )

@@ -1589,15 +1582,14 @@
             Tensor: Output tensor (#time, batch, channels).

         """
+
+        x = self.in_proj(x)  # (time, batch, 2*channels)
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=-1)  # (time, batch, channels)
+
         # exchange the temporal dimension and the feature dimension
         x = x.permute(1, 2, 0)  # (#batch, channels, time).

-        # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
-
-        x = self.deriv_balancer1(x)
-        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
-
         if src_key_padding_mask is not None:
             x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

@@ -1605,15 +1597,12 @@ class ConvolutionModule(nn.Module):
         x = self.depthwise_conv(x)

         x = self.deriv_balancer2(x)
+        x = x.permute(2, 0, 1)  # (time, batch, channels)
+
         x = self.activation(x)

+        x = self.whiten(x)  # (time, batch, channels)
+        x = self.out_proj(x)  # (time, batch, channels)
-        x = x.transpose(1, 2)
-        x = self.whiten(x)  # (batch, time, channel)
-        x = x.transpose(1, 2)
-
-        x = self.pointwise_conv2(x)  # (batch, channel, time)
-
-        x = x.permute(2, 0, 1)  # (time, batch, channel)
         return x

@@ -1732,7 +1721,9 @@ class Conv2dSubsampling(nn.Module):
         self.squeeze_excite = SqueezeExcite1d(out_height * layer3_channels,
                                               bottleneck_channels)

-        self.out = ScaledLinear(out_height * layer3_channels, out_channels)
+        self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
+                                     aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
+
         self.dropout = nn.Dropout(dropout)
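
The net effect on ConvolutionModule.forward is that activations now stay in (time, batch, channels) layout everywhere except around the depthwise conv, which is why the GLU gates on the last dimension (and why deriv_balancer1's channel_dim moves from 1 to -1), and only one permute round trip remains. A shape-only sketch of the new dataflow, with plain nn.Linear standing in for LinearWithAuxLoss and the shape-preserving balancer/whiten modules omitted (the stand-ins are assumptions for illustration, not the actual modules):

    import torch
    import torch.nn as nn

    # Shape-only sketch of the reworked forward pass.  nn.Linear stands in
    # for LinearWithAuxLoss; ActivationBalancer and Whiten are omitted here
    # because they preserve shape.
    T, B, C, K = 50, 4, 256, 31
    in_proj = nn.Linear(C, 2 * C)
    depthwise = nn.Conv1d(C, C, K, padding=(K - 1) // 2, groups=C, bias=True)
    out_proj = nn.Linear(C, C)

    x = torch.randn(T, B, C)
    x = in_proj(x)                    # (T, B, 2*C)
    x = nn.functional.glu(x, dim=-1)  # (T, B, C): gate on the last dim, not dim=1
    x = x.permute(1, 2, 0)            # (B, C, T): only the depthwise conv needs this
    x = depthwise(x)                  # (B, C, T)
    x = x.permute(2, 0, 1)            # (T, B, C): back to time-first layout
    x = out_proj(x)                   # (T, B, C)
    assert x.shape == (T, B, C)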
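
All three new LinearWithAuxLoss modules pass aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)), read here as (batch_count, value) breakpoints: the auxiliary-loss gradient scale starts at 0.2 and decays to 0.01 by batch 1000. A minimal sketch of that assumed piecewise-linear interpolation (the standalone helper below is illustrative, not the icefall implementation):

    # Illustrative helper with the semantics assumed for
    # ScheduledFloat((0.0, 0.2), (1000.0, 0.01)): a float that is a
    # piecewise-linear function of the batch count, clamped to the first
    # and last values outside the breakpoint range.
    def scheduled_value(batch_count: float, *points: tuple) -> float:
        if batch_count <= points[0][0]:
            return points[0][1]
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if batch_count <= x1:
                # linear interpolation between neighbouring breakpoints
                return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)
        return points[-1][1]

    assert scheduled_value(0, (0.0, 0.2), (1000.0, 0.01)) == 0.2
    assert abs(scheduled_value(500, (0.0, 0.2), (1000.0, 0.01)) - 0.105) < 1e-9
    assert scheduled_value(2000, (0.0, 0.2), (1000.0, 0.01)) == 0.01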