diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index be7130a97..f96ef34a9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -442,7 +442,7 @@ class ZipformerEncoderLayer(nn.Module):
                                                feedforward_dim, dropout)
 
 
-        self.nonlin_attention_module = NonlinAttentionModule(embed_dim,
+        self.nonlin_attention = NonlinAttention(embed_dim,
                                                              hidden_channels=embed_dim // 4)
 
 
@@ -461,13 +461,35 @@ class ZipformerEncoderLayer(nn.Module):
             min_positive=0.45, max_positive=0.55, min_abs=1.0, max_abs=4.0,
         )
 
+
+
+
+
+
+
+        # balancer for output of NonlinAttentionModule
+        self.balancer_na = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+            prob=0.05,  # out of concern for memory usage
+        )
+
+        # balancer for output of AttentionSqueezeModule
+        self.balancer_as = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+            prob=0.05,  # out of concern for memory usage
+        )
+
         # balancer for output of feedforward2, prevent it from staying too
         # small.  give this a very small probability, even at the start of
         # training, it's to fix a rare problem and it's OK to fix it slowly.
         self.balancer_ff2 = Balancer(
             embed_dim, channel_dim=-1,
-            min_positive=0.45, max_positive=0.55,
-            min_abs=ScheduledFloat((0.0, 0.0), (8000.0, 0.5), default=0.0),
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
             max_abs=2.0, prob=0.05,
         )
 
@@ -550,7 +572,7 @@ class ZipformerEncoderLayer(nn.Module):
             )  # else rely on the ones passed in
 
-        # use different heads for nonlin_attention_module and attention_squeeze, depending
+        # use different heads for nonlin_attention and attention_squeeze, depending
         # whether this module has its on self_attn_weights submodule or is borrowing
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2
 
@@ -569,14 +591,15 @@ class ZipformerEncoderLayer(nn.Module):
             selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.nonlin_attention_module(src,
-                                                     selected_attn_weights[0:1])
+            src = src + self.balancer_na(self.nonlin_attention(src,
+                                                               selected_attn_weights[0:1]))
 
         src = src + self.feed_forward1(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze(src, selected_attn_weights[1:2])
+            src = src + self.balancer_as(
+                self.attention_squeeze(src, selected_attn_weights[1:2]))
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
@@ -1359,14 +1382,6 @@ class AttentionSqueeze(nn.Module):
         self.out_proj = ScaledLinear(hidden_dim, embed_dim, bias=False,
                                      initial_scale=0.05)
 
-        self.out_balancer = Balancer(
-            embed_dim, channel_dim=-1,
-            min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-            prob=0.05,  # out of concern for memory usage
-        )
-
-
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor):
@@ -1402,7 +1417,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = x * scales
         x = self.activation(x)  # Identity only.  For diagnostics.
         x = self.out_proj(x)
-        x = self.out_balancer(x)
         return x
 
 
@@ -1443,7 +1457,7 @@ class FeedforwardModule(nn.Module):
         return x
 
 
-class NonlinAttentionModule(nn.Module):
+class NonlinAttention(nn.Module):
     """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
        from the attention module) in place of actual convolution.  We also took out the second nonlinearity, the
        one after the attention mechanism.
@@ -1467,7 +1481,7 @@ class NonlinAttentionModule(nn.Module):
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
         # starting from about 3, and poorly-trained instances of the module have smaller abs values
         # before the sigmoid.
-        self.balancer1 = Balancer(
+        self.balancer = Balancer(
             hidden_channels, channel_dim=-1,
             min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
             max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
@@ -1493,13 +1507,6 @@ class NonlinAttentionModule(nn.Module):
                               prob=(0.025, 0.25),
                               grad_scale=0.01)
 
-        self.balancer2 = Balancer(
-            channels, channel_dim=-1,
-            min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-            prob=0.05,  # out of concern for memory usage
-        )
-
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor,
@@ -1521,7 +1528,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         s = x[..., hidden_channels:]
         x = x[..., :hidden_channels]
 
-        s = self.balancer1(s)
+        s = self.balancer(s)
         s = self.tanh(s)
 
         s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
@@ -1541,8 +1548,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = self.out_proj(x)
         x = self.whiten2(x)
 
-        x = self.balancer2(x)
-
         return x
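
Note on the pattern this patch adopts: the Balancer that previously lived inside NonlinAttention / AttentionSqueeze is now owned by ZipformerEncoderLayer and applied to the sub-module's output at the call site, just before the residual add. Below is a minimal illustrative sketch of that call-site pattern, not part of the patch: BalancedResidual, embed_dim and submodule are placeholder names, and it assumes Balancer and ScheduledFloat can be imported from the local scaling.py, as zipformer.py already does.

# Minimal sketch (hypothetical, not part of the patch): wrapping a branch's
# output in a Balancer at the call site, then adding the residual connection.
from torch import nn, Tensor
from scaling import Balancer, ScheduledFloat  # assumed import, as in zipformer.py


class BalancedResidual(nn.Module):
    # hypothetical helper; embed_dim / submodule are placeholders
    def __init__(self, embed_dim: int, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule
        # same settings as balancer_na / balancer_as in the patch
        self.balancer = Balancer(
            embed_dim, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
            prob=0.05,  # out of concern for memory usage
        )

    def forward(self, src: Tensor, *args) -> Tensor:
        # balance the branch output, then add it to the residual stream
        return src + self.balancer(self.submodule(src, *args))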