Add balancer for keys

Daniel Povey 2023-03-07 17:39:01 +08:00
parent f59da65d82
commit e692e0b228


@@ -1203,6 +1203,22 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                                  prob=(0.025, 0.25),
                                  grad_scale=0.025)
        # add a balancer for the keys that runs with very small probability, and
        # tries to enforce that all dimensions have mean around zero. The
        # weights produced by this module are invariant to adding a constant to
        # the keys, so the derivative of the bias is mathematically zero; but
        # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
        # bias because the small numerical roundoff tends to have a non-random
        # sign. This module is intended to prevent that. Use a very small
        # probability; that should be sufficient to fix the problem.
        self.balance_keys = Balancer(key_head_dim * num_heads,
                                     channel_dim=-1,
                                     min_positive=0.4,
                                     max_positive=0.6,
                                     min_abs=0.0,
                                     max_abs=100.0,
                                     prob=0.025)
        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(pos_dim,
@@ -1256,7 +1272,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        q = self.copy_query(q)  # for diagnostics only, does nothing.
        k = self.whiten_keys(k)  # does nothing in the forward pass.
        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
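
The reasoning in the new comment can be checked in isolation with a few lines of plain PyTorch. The sketch below is illustrative only and uses none of the icefall modules: it first confirms that softmax attention weights are unchanged when a constant vector is added to every key (so the gradient with respect to such an offset is mathematically zero), and then shows how an Adam-style optimizer turns even a tiny, consistently-signed gradient, as roundoff can produce, into a large parameter drift, which is the failure mode the low-probability balancer is meant to counteract.

# Standalone sketch; assumptions: plain torch only, not the icefall Balancer/ScaledAdam.
import torch

# (1) Attention weights are invariant to a constant offset added to every key:
# the extra score q_i . c is the same for every key, so the softmax cancels it.
torch.manual_seed(0)
q = torch.randn(3, 8)    # queries
k = torch.randn(5, 8)    # keys
c = torch.randn(8)       # constant offset applied to all keys
w_plain = torch.softmax(q @ k.t(), dim=-1)
w_shift = torch.softmax(q @ (k + c).t(), dim=-1)
print(torch.allclose(w_plain, w_shift, atol=1e-5))   # True

# (2) Adam normalizes each update by the gradient's own running magnitude, so a
# tiny but consistently-signed gradient still moves the parameter by roughly lr
# per step.
bias = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([bias], lr=0.01)
for _ in range(1000):
    opt.zero_grad()
    bias.grad = torch.full_like(bias, 1e-6)   # tiny, always positive
    opt.step()
print(bias.item())   # roughly -10: a large drift from an (almost) zero gradient

Under that reading, running the balancer with prob=0.025 is presumably enough because it only has to counteract a slow drift rather than reshape the keys on every step.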