From e692e0b2285e72e82cfcc0aa9f7569ffd78d3683 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 7 Mar 2023 17:39:01 +0800
Subject: [PATCH] Add balancer for keys

---
 .../pruned_transducer_stateless7/zipformer.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 3ae6b3ce0..16939ab74 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1203,6 +1203,22 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                                   prob=(0.025, 0.25),
                                   grad_scale=0.025)

+        # add a balancer for the keys that runs with very small probability, and
+        # tries to enforce that all dimensions have mean around zero. The
+        # weights produced by this module are invariant to adding a constant to
+        # the keys, so the derivative of the bias is mathematically zero; but
+        # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+        # bias because the small numerical roundoff tends to have a non-random
+        # sign. This module is intended to prevent that. Use a very small
+        # probability; that should be sufficient to fix the problem.
+        self.balance_keys = Balancer(key_head_dim * num_heads,
+                                     channel_dim=-1,
+                                     min_positive=0.4,
+                                     max_positive=0.6,
+                                     min_abs=0.0,
+                                     max_abs=100.0,
+                                     prob=0.025)
+
         # linear transformation for positional encoding.
         self.linear_pos = ScaledLinear(pos_dim,
@@ -1256,7 +1272,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         q = self.copy_query(q)  # for diagnostics only, does nothing.
-        k = self.whiten_keys(k)  # does nothing in the forward pass.
+        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
         p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
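
For context (not part of the patch): the new comment's claim that the attention weights are invariant to adding a constant to the keys can be checked with the minimal standalone PyTorch sketch below. The tensor shapes and variable names are illustrative assumptions, not taken from zipformer.py.

# Standalone sketch (plain PyTorch, not icefall code): adding a constant vector
# c to every key shifts each row of the attention scores by the same amount
# (q_t . c), and softmax over the key dimension cancels a per-row constant, so
# the attention weights are unchanged and the true gradient of a key bias is zero.
import torch

torch.manual_seed(0)
q = torch.randn(4, 8)   # (num_queries, head_dim); illustrative sizes
k = torch.randn(6, 8)   # (num_keys, head_dim)
c = torch.randn(8)      # constant offset added to every key

attn_plain = torch.softmax(q @ k.t(), dim=-1)
attn_shift = torch.softmax(q @ (k + c).t(), dim=-1)
print(torch.allclose(attn_plain, attn_shift, atol=1e-6))  # True

This is why, as the comment says, only optimizer roundoff in Adam/ScaledAdam (and not the loss itself) can push the key bias away from zero, and why a Balancer applied with very low probability is enough to keep it in check.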