From 8f841e5b2b198bad3f2d7845a367d44352219a0a Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 6 Dec 2022 11:11:28 +0800
Subject: [PATCH] Add another balancer for NonlinAttentionModule.

---
 .../pruned_transducer_stateless7/zipformer.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 56fefbd90..08095b77c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1477,7 +1477,7 @@ class NonlinAttentionModule(nn.Module):
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
         # starting from about 3, and poorly-trained instances of the module have smaller abs values
         # before the sigmoid.
-        self.balancer = ActivationBalancer(
+        self.balancer1 = ActivationBalancer(
             hidden_channels // ratio, channel_dim=-1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
@@ -1491,6 +1491,19 @@ class NonlinAttentionModule(nn.Module):
                                      bias=True,
                                      initial_scale=0.05)
 
+        # Have very tight limits on min_positive and max_positive so that it becomes
+        # close to zero mean, as we found that large mean offsets after the
+        # multiplication are associated with poor convergence.
+        # We don't need min_abs and max_abs limits because sharing the in_proj
+        # between the sigmoid-input and activations dictates the scale of the
+        # activations at this point.  The code applies those anyway; it's not optional
+        # right now, so we just use the default values.
+        self.balancer2 = ActivationBalancer(
+            hidden_channels // ratio, channel_dim=-1,
+            min_positive=0.4, max_positive=0.6,
+            min_abs=0.5,
+        )
+
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(5.0),
                              prob=(0.025, 0.25),
@@ -1518,7 +1531,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         s = x[..., hidden_channels:]
         x = x[..., :hidden_channels]
 
-        s = self.balancer(s)
+        s = self.balancer1(s)
         s = self.tanh(s)
 
         s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size,
@@ -1536,6 +1549,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
 
+        x = self.balancer2(x)
         x = self.whiten(x)
         x = self.out_proj(x)
         return x
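
For context (not part of the patch itself): as I understand the ActivationBalancer constraints, the tight min_positive=0.4 / max_positive=0.6 limits on the new balancer2 bound the per-channel proportion of positive activations near 0.5, which is what pushes the post-attention activations toward roughly zero mean before whiten and out_proj. The snippet below is a minimal standalone sketch of that statistic in plain PyTorch; it is not icefall code, and the names fake_activations / proportion_positive are made up for illustration only.

    import torch

    # Hypothetical illustration (not the ActivationBalancer implementation):
    # the statistic that min_positive / max_positive constrain is the
    # per-channel fraction of positive entries, computed over all dims
    # except the channel dim (channel_dim=-1 here).
    fake_activations = torch.randn(50, 4, 256) + 0.8  # (seq_len, batch, channels), with a mean offset

    proportion_positive = (fake_activations > 0).float().mean(dim=(0, 1))  # one value per channel
    per_channel_mean = fake_activations.mean(dim=(0, 1))

    # Channels whose proportion-positive is held inside [0.4, 0.6] also end up
    # with means close to zero, which is the property the patch comment ties
    # to better convergence after the attention-weighted multiplication.
    print(proportion_positive[:4], per_channel_mean[:4])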