diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6d8256c26..7f9c7d7fa 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -32,8 +32,7 @@ from scaling import (
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
-    Identity,
-    _diag,
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
     softmax,
     ScheduledFloat,
@@ -212,7 +211,7 @@ class Zipformer(EncoderInterface):
         for i in range(len(z)):
             if i <= 1 or z[i-1] <= z[i]:
                 skip_layers.append(None)
-                skip_modules.append(nn.Identity())
+                skip_modules.append(Identity())
             else:
                 # TEMP
                 for j in range(i-2, -1, -1):
@@ -1300,11 +1299,12 @@ class AttentionSqueeze(nn.Module):
         self.bottleneck_balancer = ActivationBalancer(
             bottleneck_dim, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
-            min_abs=0.2,
-            max_abs=1.0,
+            min_abs=0.05,
+            max_abs=2.0,
             max_factor=0.02,
             min_prob=0.1,
         )
+        self.activation = DoubleSwish()  # in bottleneck
 
         # the next two balancers are only to stop parameter-magnitude 'drift': we have
         # too many degrees of freedom for the scales of the various activations.
@@ -1313,15 +1313,15 @@ class AttentionSqueeze(nn.Module):
         # put no limitation on the signs (so: min_positive=0, max_positive=1).
         self.scale_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
-            min_positive=0.0, max_positive=1.0,
+            min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.025,
+            min_prob=0.05,
         )
         self.activation_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
-            min_positive=0.0, max_positive=1.0,
+            min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.025,
+            min_prob=0.05,
         )
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
@@ -1356,14 +1356,14 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # -> (num_heads, batch_size, seq_len, head_dim)
         bottleneck = torch.matmul(attn_weights, bottleneck)
         bottleneck = self.bottleneck_balancer(bottleneck)
+        bottleneck = self.activation(bottleneck)
         bottleneck = bottleneck.permute(2, 1, 0, 3)  # (seq_len, batch_size, num_heads, head_dim)
         bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
 
         scales = self.from_bottleneck_proj(bottleneck)
 
-        scales = self.scale_balancer(scales)
-        x = self.activation_balancer(x)
-        x = self.in_proj(x)
+        x = self.activation_balancer(x)
+        scales = self.scale_balancer(scales)
         x = x * scales
         x = self.out_proj(x)
         x = self.out_whiten(x)
@@ -1399,7 +1399,8 @@ class FeedforwardModule(nn.Module):
 
 class NonlinAttentionModule(nn.Module):
     """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
-    from the attention module) in place of actual convolution.
+    from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
+    one after the attention mechanism.
 
     Args:
         channels (int): The number of channels of conv layers.
@@ -1410,24 +1411,9 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
 
-        # to_scale and to_value are analogous to pointwise_conv1 in ConvolutionModule
-        # we make them separate because we need an extra degree of freedom for the
-        # scale, as the attention weights are constrained to sum to one so cannot
-        # provide the degree of freedom for the scale of the features before
-        # self.activation().
-        self.to_scale = nn.Linear(channels, channels, bias=True)
-        self.to_value = nn.Linear(channels, channels, bias=True)
-
-
-        # deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
-        self.deriv_balancer = ActivationBalancer(
-            channels, channel_dim=-1,
-            min_positive=0.05, max_positive=1.0,
-            max_abs=20.0,
-        )
-
-        self.activation = DoubleSwish()
-
+        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
+        self.activation = Identity()  # for diagnostics.
+
 
         self.out_proj = ScaledLinear(channels, channels, bias=True, initial_scale=0.05)
 
@@ -1443,8 +1429,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         Returns:
            a Tensor with the same shape as x
         """
-        s = self.to_scale(x)
-        v = self.to_value(x)
+        v, s = self.in_proj(x).chunk(2, dim=-1)
+
         if self.training and random.random() < 0.02:
             # prevent the inputs to the sigmoid from getting very large (this is
             # hopefully quite a rare phenomenon, so we are giving this path a
@@ -1463,8 +1449,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-        x = self.deriv_balancer(x)
-        x = self.activation(x)
+
+        x = self.activation(x)  # diagnostics only, it's the identity.
         x = self.out_proj(x)
         return x
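A few illustrative notes on the changes above (plain-PyTorch sketches with made-up shapes; these are not part of the patch).

On the ActivationBalancer settings in AttentionSqueeze: min_positive/max_positive bound the per-channel proportion of positive activations, and min_prob is (roughly) the minimum probability with which the balancing correction is applied on a given forward pass, so tightening (0.0, 1.0) to (0.2, 0.8) and raising min_prob from 0.025 to 0.05 asks each channel to be neither almost-always negative nor almost-always positive, checked a little more often. A sketch of the statistic being constrained:

    import torch

    # The per-channel quantity that min_positive / max_positive constrain is roughly
    # the fraction of activations that are > 0 (channel_dim=-1, i.e. the last dim).
    x = torch.randn(100, 4, 384)                      # (seq_len, batch_size, embed_dim)
    frac_positive = (x > 0).float().mean(dim=(0, 1))  # one value per channel
    print(frac_positive.min().item(), frac_positive.max().item())

    # With min_positive=0.2, max_positive=0.8 the balancer nudges gradients so these
    # fractions drift back into [0.2, 0.8]; it is a soft, randomly applied correction
    # (on at least min_prob of forward passes), not a hard clamp.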
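The new self.activation = DoubleSwish() in the AttentionSqueeze bottleneck is applied right after bottleneck_balancer, before the permute/reshape back to (seq_len, batch_size, bottleneck_dim). A minimal stand-in for what that activation computes, assuming the usual icefall definition x * sigmoid(x - 1); the DoubleSwish class in scaling.py is the authoritative (and likely more memory-efficient) implementation:

    import torch

    def double_swish(x: torch.Tensor) -> torch.Tensor:
        # Assumed definition: x * sigmoid(x - 1.0); smooth, near zero for large
        # negative inputs and roughly linear for large positive inputs.
        return x * torch.sigmoid(x - 1.0)

    bottleneck = torch.randn(8, 4, 10, 16)  # (num_heads, batch_size, seq_len, head_dim)
    bottleneck = double_swish(bottleneck)   # the step added after bottleneck_balancer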
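On NonlinAttentionModule: the separate to_scale / to_value projections are folded into a single in_proj of width 2 * channels whose output is split with chunk(2, dim=-1), and the deriv_balancer / DoubleSwish pair after the attention product is dropped (the remaining self.activation is an Identity kept only so diagnostics have a module to hook). A small sketch, with illustrative shapes, showing that the fused projection is mathematically the same as two separate Linear layers whose weights are the two halves of the fused weight:

    import torch
    import torch.nn as nn

    channels = 8
    x = torch.randn(5, 3, channels)     # (seq_len, batch_size, channels)

    in_proj = nn.Linear(channels, 2 * channels, bias=True)
    v, s = in_proj(x).chunk(2, dim=-1)  # value and scale, each (5, 3, channels)

    # Two separate projections built from the halves of the fused weight/bias:
    to_value = nn.Linear(channels, channels, bias=True)
    to_scale = nn.Linear(channels, channels, bias=True)
    with torch.no_grad():
        to_value.weight.copy_(in_proj.weight[:channels])
        to_value.bias.copy_(in_proj.bias[:channels])
        to_scale.weight.copy_(in_proj.weight[channels:])
        to_scale.bias.copy_(in_proj.bias[channels:])

    assert torch.allclose(v, to_value(x), atol=1e-6)
    assert torch.allclose(s, to_scale(x), atol=1e-6)

The fused form does one matmul instead of two and keeps both sets of parameters in one module; the number of degrees of freedom is unchanged.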