Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Add another balancer to ZipformerEncoderLayer, prior to output.
commit da0623aa7f
parent 0c3530a6fd

@@ -452,20 +452,25 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm_final = BasicNorm(embed_dim)
+        self.norm = BasicNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = Balancer(
+        self.balancer1 = Balancer(
             embed_dim, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
-            min_abs=1.0, max_abs=6.0,
+            min_abs=1.0, max_abs=4.0,
         )
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(4.0, ratio=3.0),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
+        self.balancer2 = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            min_abs=0.5, max_abs=2.0,
+        )
 
     def remove_attention_weights(self):
         self.self_attn_weights = None
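For orientation on what those Balancer arguments constrain (my reading of icefall's scaling.py, so treat it as a hedged summary): the module is an identity in the forward pass, and in the backward pass it nudges gradients so that, per channel, the fraction of positive activations stays within [min_positive, max_positive] and the mean absolute value stays within [min_abs, max_abs]. Below is a minimal sketch that only measures those two statistics; check_balancer_stats is a hypothetical helper for illustration, not icefall's implementation.

import torch

def check_balancer_stats(x: torch.Tensor, channel_dim: int = -1,
                         min_positive: float = 0.45, max_positive: float = 0.55,
                         min_abs: float = 0.5, max_abs: float = 2.0) -> dict:
    # Collapse every dim except the channel dim, then measure the two
    # per-channel statistics that Balancer constrains.
    num_channels = x.shape[channel_dim]
    x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
    frac_positive = (x > 0).float().mean(dim=0)  # fraction of positive values
    mean_abs = x.abs().mean(dim=0)               # mean absolute value
    return {
        "positive_ok": bool(((frac_positive >= min_positive)
                             & (frac_positive <= max_positive)).all()),
        "abs_ok": bool(((mean_abs >= min_abs) & (mean_abs <= max_abs)).all()),
    }

# e.g. measuring against the bounds the new balancer2 targets:
print(check_balancer_stats(torch.randn(100, 4, 256), min_abs=0.5, max_abs=2.0))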
@@ -571,12 +576,15 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
-        src = self.balancer(src)
-        src = self.norm_final(src)
+        src = self.balancer1(src)
+        src = self.norm(src)
 
         bypass_scale = self.get_bypass_scale(src.shape[1])
-        src = src * bypass_scale + src_orig * (1.0 - bypass_scale)
+        # the next line is equivalent to: src = src * bypass_scale + src_orig *
+        # (1.0 - bypass_scale), but more memory efficient for backprop.
+        src = src_orig + (src - src_orig) * bypass_scale
 
+        src = self.balancer2(src)
         src = self.whiten(src)
 
         return src, attn_weights
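The bypass rewrite in the second hunk is a pure algebraic identity, src_orig + (src - src_orig) * b == src * b + src_orig * (1 - b); the committed comment credits the new form with being more memory efficient for backprop. A quick standalone check of the equivalence (tensor shapes here are illustrative only, not taken from icefall):

import torch

torch.manual_seed(0)
src = torch.randn(10, 4, 256)       # (seq_len, batch, embed_dim), illustrative shapes
src_orig = torch.randn(10, 4, 256)
bypass_scale = torch.rand(256)      # per-channel scale, broadcasts over leading dims

old_form = src * bypass_scale + src_orig * (1.0 - bypass_scale)
new_form = src_orig + (src - src_orig) * bypass_scale

# The two forms agree up to floating-point rounding.
assert torch.allclose(old_form, new_form, atol=1e-6)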