Add another balancer to ZipformerEncoderLayer, prior to output.

Daniel Povey 2022-12-30 14:35:03 +08:00
parent 0c3530a6fd
commit da0623aa7f

@@ -452,20 +452,25 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim)
+        self.norm = BasicNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
-        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = Balancer(
+        self.balancer1 = Balancer(
             embed_dim, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
-            min_abs=1.0, max_abs=6.0,
+            min_abs=1.0, max_abs=4.0,
         )
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(4.0, ratio=3.0),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
+        self.balancer2 = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            min_abs=0.5, max_abs=2.0,
+        )
+
     def remove_attention_weights(self):
         self.self_attn_weights = None
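For context on the Balancer arguments: min_positive/max_positive bound the per-channel fraction of positive activations (a fraction near 0.5 corresponds to a roughly zero median, which is what the removed comment was about), and min_abs/max_abs bound the per-channel mean absolute value, so the new balancer2 nudges the layer output toward a 0.5 to 2.0 magnitude range during training. The snippet below is only a diagnostic sketch of the statistics involved; the function name and tensor shapes are made up for illustration and are not part of this commit or of icefall's scaling.py.

import torch

def channel_stats(x: torch.Tensor, channel_dim: int = -1):
    # Per-channel statistics of the kind a Balancer constrains:
    # the fraction of positive values and the mean absolute value.
    dims = [d for d in range(x.ndim) if d != channel_dim % x.ndim]
    frac_positive = (x > 0).float().mean(dim=dims)
    mean_abs = x.abs().mean(dim=dims)
    return frac_positive, mean_abs

# Hypothetical layer output of shape (seq_len, batch, embed_dim).
x = torch.randn(100, 8, 384)
frac_positive, mean_abs = channel_stats(x)
# balancer2 would steer these toward 0.45 <= frac_positive <= 0.55
# and 0.5 <= mean_abs <= 2.0 during training.
print(frac_positive.mean().item(), mean_abs.mean().item())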
@@ -571,12 +576,15 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
-        src = self.balancer(src)
-        src = self.norm_final(src)
+        src = self.balancer1(src)
+        src = self.norm(src)
 
         bypass_scale = self.get_bypass_scale(src.shape[1])
-        src = src * bypass_scale + src_orig * (1.0 - bypass_scale)
+        # the next line equivalent to: src = src * bypass_scale + src_orig *
+        # (1.0 - bypass_scale), but more memory efficient for backprop.
+        src = src_orig + (src - src_orig) * bypass_scale
+        src = self.balancer2(src)
 
         src = self.whiten(src)
 
         return src, attn_weights
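The rewritten bypass line is the same convex combination: src_orig + (src - src_orig) * bypass_scale expands to src * bypass_scale + src_orig * (1.0 - bypass_scale). A quick standalone check (illustrative only, not code from the commit; the shapes, requires_grad setup, and the random stand-in for the learned scale are assumptions):

import torch

torch.manual_seed(0)
src_orig = torch.randn(100, 8, 384, requires_grad=True)
src = torch.randn(100, 8, 384, requires_grad=True)
bypass_scale = torch.rand(384)  # stand-in for the per-channel bypass scale

old_form = src * bypass_scale + src_orig * (1.0 - bypass_scale)
new_form = src_orig + (src - src_orig) * bypass_scale

print(torch.allclose(old_form, new_form, atol=1e-5))  # True up to rounding

The memory claim in the comment: taken in isolation, the original form has two elementwise multiplications that each save a full-size activation (src and src_orig) for the backward pass, whereas the rewritten form has a single multiplication that saves only (src - src_orig), so roughly one full-size tensor is kept instead of two.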